当前位置: article > 正文

TransE 测试代码

TransE 代码
  1. import numpy as np
  2. import codecs
  3. import operator
  4. import json
  5. from transE import data_loader,entity2id,relation2id
  6. def dataloader(entity_file,relation_file,test_file):
  7. # entity_file: entity \t embedding
  8. entity_dict = {}
  9. relation_dict = {}
  10. test_triple = []
  11. with codecs.open(entity_file) as e_f:
  12. lines = e_f.readlines()
  13. for line in lines:
  14. entity,embedding = line.strip().split('\t')
  15. embedding = json.loads(embedding)
  16. entity_dict[entity] = embedding
  17. with codecs.open(relation_file) as r_f:
  18. lines = r_f.readlines()
  19. for line in lines:
  20. relation,embedding = line.strip().split('\t')
  21. embedding = json.loads(embedding)
  22. relation_dict[relation] = embedding
  23. with codecs.open(test_file) as t_f:
  24. lines = t_f.readlines()
  25. for line in lines:
  26. triple = line.strip().split('\t')
  27. if len(triple) != 3:
  28. continue
  29. h_ = entity2id[triple[0]]
  30. t_ = entity2id[triple[1]]
  31. r_ = relation2id[triple[2]]
  32. test_triple.append(tuple((h_,t_,r_)))
  33. return entity_dict,relation_dict,test_triple
  34. def distance(h,r,t):
  35. h = np.array(h)
  36. r=np.array(r)
  37. t = np.array(t)
  38. s=h+r-t
  39. return np.linalg.norm(s)
  40. class Test:
  41. def __init__(self,entity_dict,relation_dict,test_triple,train_triple,isFit = True):
  42. self.entity_dict = entity_dict
  43. self.relation_dict = relation_dict
  44. self.test_triple = test_triple
  45. self.train_triple = train_triple
  46. self.isFit = isFit
  47. self.hits10 = 0
  48. self.mean_rank = 0
  49. self.relation_hits10 = 0
  50. self.relation_mean_rank = 0
  51. def rank(self):
  52. hits = 0
  53. rank_sum = 0
  54. step = 1
  55. for triple in self.test_triple:
  56. rank_head_dict = {}
  57. rank_tail_dict = {}
  58. for entity in self.entity_dict.keys():
  59. corrupted_head = [entity,triple[1],triple[2]]
  60. if self.isFit:
  61. if corrupted_head not in self.train_triple:
  62. h_emb = self.entity_dict[corrupted_head[0]]
  63. r_emb = self.relation_dict[corrupted_head[2]]
  64. t_emb = self.entity_dict[corrupted_head[1]]
  65. rank_head_dict[tuple(corrupted_head)]=distance(h_emb,r_emb,t_emb)
  66. else:
  67. h_emb = self.entity_dict[corrupted_head[0]]
  68. r_emb = self.relation_dict[corrupted_head[2]]
  69. t_emb = self.entity_dict[corrupted_head[1]]
  70. rank_head_dict[tuple(corrupted_head)] = distance(h_emb, r_emb, t_emb)
  71. corrupted_tail = [triple[0],entity,triple[2]]
  72. if self.isFit:
  73. if corrupted_tail not in self.train_triple:
  74. h_emb = self.entity_dict[corrupted_tail[0]]
  75. r_emb = self.relation_dict[corrupted_tail[2]]
  76. t_emb = self.entity_dict[corrupted_tail[1]]
  77. rank_tail_dict[tuple(corrupted_tail)] = distance(h_emb, r_emb, t_emb)
  78. else:
  79. h_emb = self.entity_dict[corrupted_tail[0]]
  80. r_emb = self.relation_dict[corrupted_tail[2]]
  81. t_emb = self.entity_dict[corrupted_tail[1]]
  82. rank_tail_dict[tuple(corrupted_tail)] = distance(h_emb, r_emb, t_emb)
  83. rank_head_sorted = sorted(rank_head_dict.items(),key = operator.itemgetter(1))
  84. rank_tail_sorted = sorted(rank_tail_dict.items(),key = operator.itemgetter(1))
  85. #rank_sum and hits
  86. for i in range(len(rank_head_sorted)):
  87. if triple[0] == rank_head_sorted[i][0][0]:
  88. if i<10:
  89. hits += 1
  90. rank_sum = rank_sum + i + 1
  91. break
  92. for i in range(len(rank_tail_sorted)):
  93. if triple[1] == rank_tail_sorted[i][0][1]:
  94. if i<10:
  95. hits += 1
  96. rank_sum = rank_sum + i + 1
  97. break
  98. step += 1
  99. if step % 5000 == 0:
  100. print("step ", step, " ,hits ",hits," ,rank_sum ",rank_sum)
  101. print()
  102. self.hits10 = hits / (2*len(self.test_triple))
  103. self.mean_rank = rank_sum / (2*len(self.test_triple))
  104. def relation_rank(self):
  105. hits = 0
  106. rank_sum = 0
  107. step = 1
  108. for triple in self.test_triple:
  109. rank_dict = {}
  110. for r in self.relation_dict.keys():
  111. corrupted_relation = (triple[0],triple[1],r)
  112. if self.isFit and corrupted_relation in self.train_triple:
  113. continue
  114. h_emb = self.entity_dict[corrupted_relation[0]]
  115. r_emb = self.relation_dict[corrupted_relation[2]]
  116. t_emb = self.entity_dict[corrupted_relation[1]]
  117. rank_dict[r]=distance(h_emb, r_emb, t_emb)
  118. rank_sorted = sorted(rank_dict.items(),key = operator.itemgetter(1))
  119. rank = 1
  120. for i in rank_sorted:
  121. if triple[2] == i[0]:
  122. break
  123. rank += 1
  124. if rank<10:
  125. hits += 1
  126. rank_sum = rank_sum + rank + 1
  127. step += 1
  128. if step % 5000 == 0:
  129. print("relation step ", step, " ,hits ", hits, " ,rank_sum ", rank_sum)
  130. print()
  131. self.relation_hits10 = hits / len(self.test_triple)
  132. self.relation_mean_rank = rank_sum / len(self.test_triple)
  133. if __name__ == '__main__':
  134. _, _, train_triple = data_loader("FB15k\\")
  135. entity_dict, relation_dict, test_triple = \
  136. dataloader("entity_50dim_batch400","relation50dim_batch400",
  137. "FB15k\\test.txt")
  138. test = Test(entity_dict,relation_dict,test_triple,train_triple,isFit=False)
  139. test.rank()
  140. print("entity hits@10: ", test.hits10)
  141. print("entity meanrank: ", test.mean_rank)
  142. test.relation_rank()
  143. print("relation hits@10: ", test.relation_hits10)
  144. print("relation meanrank: ", test.relation_mean_rank)
  145. f = open("result.txt",'w')
  146. f.write("entity hits@10: "+ str(test.hits10) + '\n')
  147. f.write("entity meanrank: " + str(test.mean_rank) + '\n')
  148. f.write("relation hits@10: " + str(test.relation_hits10) + '\n')
  149. f.write("relation meanrank: " + str(test.relation_mean_rank) + '\n')
  150. f.close()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/549571
推荐阅读
  

闽ICP备14008679号