
Knowledge Graph Embedding: TransE Algorithm Principles and Code Walkthrough


Contents

KGE

TransE

TransE Code Walkthrough


KGE

In a knowledge graph, knowledge is stored as discrete symbols, which cannot be used directly for semantic computation. To let computers compute over knowledge and to alleviate data sparsity, the entities and relations of a knowledge graph can be mapped into a low-dimensional continuous vector space. This family of methods is called Knowledge Graph Embedding (KGE).

TransE

Inspired by the translation invariance observed in word vectors, TransE interprets a relation's vector as a translation vector between the head and tail entity vectors. The algorithm is simple and efficient, and it learns a certain amount of semantic information during training. The basic idea: if a triple (h, l, t) holds, the corresponding vectors should satisfy h + l ≈ t. For example:

vec(Rome) + vec(is-capital-of) ≈ vec(Italy)

vec(Paris) + vec(is-capital-of) ≈ vec(France)

(Figure: TransE as translation distance in the embedding space, h + l ≈ t)

This makes it possible to complete missing triples such as (Beijing, is-capital-of, ?), (Beijing, ?, China), and (?, is-capital-of, China), i.e., link prediction.
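To make link prediction concrete, here is a minimal sketch using hypothetical trained vectors (the toy 2-d embeddings and entity names are illustrative only, not output of the code below); candidate tails for (h, l, ?) are ranked by the distance ||h + l - t||:

import numpy as np

# hypothetical trained embeddings; in practice these come from the training code below
entities = {
    "Beijing": np.array([0.1, 0.3]),
    "China": np.array([0.5, 0.9]),
    "Paris": np.array([0.8, 0.2]),
}
relations = {"is-capital-of": np.array([0.4, 0.6])}

def predict_tail(h, l, entities, relations, k=2):
    # rank candidate tails for (h, l, ?) by the L2 distance ||h + l - t||
    source = entities[h] + relations[l]
    scores = {t: np.linalg.norm(source - vec) for t, vec in entities.items() if t != h}
    return sorted(scores, key=scores.get)[:k]

print(predict_tail("Beijing", "is-capital-of", entities, relations))
# here vec(Beijing) + vec(is-capital-of) = [0.5, 0.9] = vec(China), so "China" ranks first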

TransE is the earliest translation-based model. It was followed by TransD, TransR, TransH, TransA, and others, which are variations on the same theme, mainly improving and extending TransE.

Advantages:

Alleviates data sparsity and improves the efficiency of knowledge computation.

Automatically captures features useful for reasoning, with no manual feature engineering required.

The algorithm is simple, has few parameters to learn, and its computational complexity is low.

Disadvantages:

Cannot effectively model complex relations such as one-to-many, many-to-one, many-to-many, and reflexive relations.

Considers only one-hop relations, ignoring longer-range implicit relations.

The embedding model converges slowly.

Pseudocode:

Input: training set S = \left \{ (h,l,t)\right \}, entity set E, relation set L, margin \gamma, embedding dimension k

1: initialize    each relation vector l \in L  ←  sampled uniformly from (-\frac{6}{\sqrt{k}},\frac{6}{\sqrt{k}})

2:               each relation vector l \in L  ←  divided by its own L2 norm

3:               each entity vector e \in E  ←  sampled uniformly from (-\frac{6}{\sqrt{k}},\frac{6}{\sqrt{k}})

4: loop:

5:               each entity vector e \in E  ←  divided by its own L2 norm

6:               sample b triples from the training set S as a batch S_{batch}

7:               initialize the triple set T_{batch} as an empty list

8:               for each (h,l,t) \in S_{batch} do

9:                            construct a negative triple (h^{'},l,t) or (h,l,t^{'}) by replacing the head or tail entity of the correct triple

10:                          put both the positive triple and the negative triple into the T_{batch} list

11:             end for

12:            update entity and relation vectors by gradient descent

13: end loop


TransE Code Walkthrough

1. Loading the data

Takes the paths of three data files: the training set S = \left \{ (h,l,t)\right \}, the entity set E, and the relation set L.

Returns three lists: entities, relations, and triples (entities and relations are represented by their ids).

import codecs
import numpy as np
import copy
import time
import random

def dataloader(file1, file2, file3):
    print("load file...")
    entity = []
    relation = []
    entities2id = {}
    relations2id = {}
    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            entities2id[line[0]] = line[1]
            entity.append(line[1])
        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            relations2id[line[0]] = line[1]
            relation.append(line[1])
    triple_list = []
    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue
            h_ = entities2id[triple[0]]
            r_ = relations2id[triple[1]]
            t_ = entities2id[triple[2]]
            triple_list.append([h_, r_, t_])
    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity), len(relation), len(triple_list)))
    return entity, relation, triple_list

2. Parameters

Takes the entity id list entity, the relation id list relation, and the triple list triple_list, along with the embedding dimension embedding_dim=50, the learning rate lr=0.01, margin (the gap enforced between the scores of positive and negative triples), and norm (which norm to use); loss accumulates the running loss value.

class TransE:
    def __init__(self, entity, relation, triple_list, embedding_dim=50, lr=0.01, margin=1.0, norm=1):
        self.entities = entity
        self.relations = relation
        self.triples = triple_list
        self.dimension = embedding_dim
        self.learning_rate = lr
        self.margin = margin
        self.norm = norm
        self.loss = 0.0

3. Initialization

This corresponds to steps 1-3 of the pseudocode.

Converts the entity id list and relation id list into two dictionaries, {entity id: entity vector} and {relation id: relation vector}.

class TransE:
    def data_initialise(self):
        entityVectorList = {}
        relationVectorList = {}
        for entity in self.entities:
            entity_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension), self.dimension)
            entityVectorList[entity] = entity_vector
        for relation in self.relations:
            relation_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension), self.dimension)
            relation_vector = self.normalization(relation_vector)
            relationVectorList[relation] = relation_vector
        self.entities = entityVectorList
        self.relations = relationVectorList

    def normalization(self, vector):
        return vector / np.linalg.norm(vector)

4. Training loop

This corresponds to steps 4-13 of the pseudocode.

nbatches=100, i.e., the dataset is split into 100 batches that are trained in turn; the number of samples per batch is batch_size. epochs=1 is the number of complete passes over all 100 batches.

First, the entity vectors are normalized.

For each batch, batch_size triples are randomly sampled as S_{batch}, called batch_samples in the code.

The triple set T_{batch} is initialized as an empty list.

For each sample in batch_samples, a negative triple is generated by randomly replacing either the head or the tail entity.

Here, while corrupted_sample[0] == sample[0] filters out positive triples by ensuring the entity sampled from the entity set is not the original entity itself. Strictly speaking, the check should be while corrupted_sample in self.triples, which also guards against the case where the sampled entity h2 differs from the original h1 yet the triple is still positive (i.e., both (h1,l,t) and (h2,l,t) appear in the triple list and both hold). However, that check scans the entire triple list and makes training roughly ten times slower, so it is simplified here; a set-based alternative that avoids the scan is sketched below.
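One way to keep the strict filter without the linear scan (a sketch, not part of the article's code; the helper name corrupt_head and its arguments are mine) is to build a set of triples once and test membership in O(1):

import random

def corrupt_head(sample, entity_ids, triple_set):
    # sample is [h, l, t]; triple_set is built once as {tuple(t) for t in triple_list}
    corrupted = list(sample)
    corrupted[0] = random.choice(entity_ids)
    # resample while the corrupted triple is still a known true triple
    while tuple(corrupted) in triple_set:
        corrupted[0] = random.choice(entity_ids)
    return corrupted

Because set membership is O(1) on average, the stricter filter no longer multiplies training time.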

Both the positive triple and the negative triple are put into the T_{batch} list.

update_triple_embedding is then called to compute the loss for this batch and update the vectors by gradient descent, before moving on to the next batch.

After all 100 batches have been trained, the learned entity and relation vectors are written out to files prefixed with out_file_title (an empty prefix means they are saved in the current directory).

class TransE:
    def training_run(self, epochs=1, nbatches=100, out_file_title=''):
        batch_size = int(len(self.triples) / nbatches)
        print("batch size: ", batch_size)
        for epoch in range(epochs):
            start = time.time()
            self.loss = 0.0
            # normalise the entity embeddings to unit length
            for entity in self.entities.keys():
                self.entities[entity] = self.normalization(self.entities[entity])
            for batch in range(nbatches):
                batch_samples = random.sample(self.triples, batch_size)
                Tbatch = []
                for sample in batch_samples:
                    corrupted_sample = copy.deepcopy(sample)
                    pr = np.random.random(1)[0]
                    # list(...) keeps random.sample working on Python 3.11+,
                    # where set-like populations are no longer accepted
                    if pr > 0.5:
                        # change the head entity
                        corrupted_sample[0] = random.sample(list(self.entities.keys()), 1)[0]
                        while corrupted_sample[0] == sample[0]:
                            corrupted_sample[0] = random.sample(list(self.entities.keys()), 1)[0]
                    else:
                        # change the tail entity
                        corrupted_sample[2] = random.sample(list(self.entities.keys()), 1)[0]
                        while corrupted_sample[2] == sample[2]:
                            corrupted_sample[2] = random.sample(list(self.entities.keys()), 1)[0]
                    if (sample, corrupted_sample) not in Tbatch:
                        Tbatch.append((sample, corrupted_sample))
                self.update_triple_embedding(Tbatch)
            end = time.time()
            print("epoch: ", epoch, "cost time: %s" % (round((end - start), 3)))
            print("running loss: ", self.loss)
        with codecs.open(out_file_title + "TransE_entity_" + str(self.dimension) + "dim_batch" + str(batch_size), "w") as f1:
            for e in self.entities.keys():
                f1.write(e + "\t")
                f1.write(str(list(self.entities[e])))
                f1.write("\n")
        with codecs.open(out_file_title + "TransE_relation_" + str(self.dimension) + "dim_batch" + str(batch_size), "w") as f2:
            for r in self.relations.keys():
                f2.write(r + "\t")
                f2.write(str(list(self.relations[r])))
                f2.write("\n")
5. Gradient descent

First, deepcopy is used to make deep copies of the entity and relation dictionaries; the vectors for the entity and relation ids are looked up, and the score is computed with either the L1 or the L2 norm.

L1-norm score: np.sum(np.fabs(h + r - t))

L2-norm score: np.sum(np.square(h + r - t))
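The update code below calls two helper functions, norm_l1 and norm_l2, which the article never defines; a minimal sketch consistent with the two scores above would be:

import numpy as np

def norm_l1(h, r, t):
    # L1 distance score ||h + r - t||_1
    return np.sum(np.fabs(h + r - t))

def norm_l2(h, r, t):
    # squared L2 distance score ||h + r - t||_2^2
    return np.sum(np.square(h + r - t))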

The loss is then computed according to the following formula (\gamma is the margin, d is the distance score above, and \left [ x \right ]_{+} = max(0, x)):

L = \sum_{((h,l,t),(h^{'},l,t^{'})) \in T_{batch}} \left [ \gamma + d(h+l,t) - d(h^{'}+l,t^{'}) \right ]_{+}

For the L2 norm, the gradient is computed according to the following formula:

\nabla_{h} \left \| h+l-t \right \|_{2}^{2} = \nabla_{l} \left \| h+l-t \right \|_{2}^{2} = 2(h+l-t), \quad \nabla_{t} \left \| h+l-t \right \|_{2}^{2} = -2(h+l-t)

which is correct_gradient = 2 * (correct_head + relation - correct_tail) in the code.

For the L1 norm, each element of the gradient vector is -1 or 1, i.e., the sign of the corresponding element of h + l - t.
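The update code below implements this sign with an explicit element-wise loop; an equivalent, more idiomatic alternative (a sketch, not the article's code) would be:

correct_gradient = np.sign(correct_head + relation - correct_tail)
corrupted_gradient = np.sign(corrupted_head + relation - corrupted_tail)
# note: np.sign returns 0 for exact zeros, where the loop returns -1;
# with continuous embeddings this case is vanishingly rare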

Finally, the entity and relation vectors are updated along the gradients and re-normalized.

class TransE:
    def update_triple_embedding(self, Tbatch):
        copy_entity = copy.deepcopy(self.entities)
        copy_relation = copy.deepcopy(self.relations)
        for correct_sample, corrupted_sample in Tbatch:
            correct_copy_head = copy_entity[correct_sample[0]]
            correct_copy_tail = copy_entity[correct_sample[2]]
            relation_copy = copy_relation[correct_sample[1]]
            corrupted_copy_head = copy_entity[corrupted_sample[0]]
            corrupted_copy_tail = copy_entity[corrupted_sample[2]]
            correct_head = self.entities[correct_sample[0]]
            correct_tail = self.entities[correct_sample[2]]
            relation = self.relations[correct_sample[1]]
            corrupted_head = self.entities[corrupted_sample[0]]
            corrupted_tail = self.entities[corrupted_sample[2]]
            # calculate the distances of the triples
            if self.norm == 1:
                correct_distance = norm_l1(correct_head, relation, correct_tail)
                corrupted_distance = norm_l1(corrupted_head, relation, corrupted_tail)
            else:
                correct_distance = norm_l2(correct_head, relation, correct_tail)
                corrupted_distance = norm_l2(corrupted_head, relation, corrupted_tail)
            loss = self.margin + correct_distance - corrupted_distance
            if loss > 0:
                self.loss += loss
                correct_gradient = 2 * (correct_head + relation - correct_tail)
                corrupted_gradient = 2 * (corrupted_head + relation - corrupted_tail)
                if self.norm == 1:
                    for i in range(len(correct_gradient)):
                        if correct_gradient[i] > 0:
                            correct_gradient[i] = 1
                        else:
                            correct_gradient[i] = -1
                        if corrupted_gradient[i] > 0:
                            corrupted_gradient[i] = 1
                        else:
                            corrupted_gradient[i] = -1
                correct_copy_head -= self.learning_rate * correct_gradient
                relation_copy -= self.learning_rate * correct_gradient
                correct_copy_tail -= -1 * self.learning_rate * correct_gradient
                relation_copy -= -1 * self.learning_rate * corrupted_gradient
                if correct_sample[0] == corrupted_sample[0]:
                    # if the corrupted triple replaces the tail entity, the head entity's embedding needs to be updated twice
                    correct_copy_head -= -1 * self.learning_rate * corrupted_gradient
                    corrupted_copy_tail -= self.learning_rate * corrupted_gradient
                elif correct_sample[2] == corrupted_sample[2]:
                    # if the corrupted triple replaces the head entity, the tail entity's embedding needs to be updated twice
                    corrupted_copy_head -= -1 * self.learning_rate * corrupted_gradient
                    correct_copy_tail -= self.learning_rate * corrupted_gradient
                # normalise these new embedding vectors, instead of normalising all the embeddings together
                copy_entity[correct_sample[0]] = self.normalization(correct_copy_head)
                copy_entity[correct_sample[2]] = self.normalization(correct_copy_tail)
                if correct_sample[0] == corrupted_sample[0]:
                    # if the corrupted triple replaces the tail entity, update the corrupted tail entity's embedding
                    copy_entity[corrupted_sample[2]] = self.normalization(corrupted_copy_tail)
                elif correct_sample[2] == corrupted_sample[2]:
                    # if the corrupted triple replaces the head entity, update the corrupted head entity's embedding
                    copy_entity[corrupted_sample[0]] = self.normalization(corrupted_copy_head)
                # the paper mentions that the relation embeddings don't need to be normalised
                copy_relation[correct_sample[1]] = relation_copy
                # copy_relation[correct_sample[1]] = self.normalization(relation_copy)
        self.entities = copy_entity
        self.relations = copy_relation

6. __main__

if __name__ == '__main__':
    # file1 = "FB15k\\train.txt"
    # file2 = "FB15k\\entity2id.txt"
    # file3 = "FB15k\\relation2id.txt"
    file1 = "WN18\\wordnet-mlj12-train.txt"
    file2 = "WN18\\entity2id.txt"
    file3 = "WN18\\relation2id.txt"
    entity_set, relation_set, triple_list = dataloader(file1, file2, file3)
    transE = TransE(entity_set, relation_set, triple_list, embedding_dim=50, lr=0.01, margin=1.0, norm=2)
    transE.data_initialise()
    transE.training_run(out_file_title="WN18_")

References:

The code comes from: 论文笔记(一):TransE论文详解及代码复现 - 知乎 (Zhihu); the complete code can be downloaded there via the "完整代码" link.
