
Learning the Java Implementation of TransE


References

Anery/transE: a simple Python implementation of the transE algorithm on FB15k (github.com)

Translating Embeddings for Modeling Multi-relational Data (nips.cc)

Input

1. Training set S
2. Entity set E
3. Relation set L
4. Margin hyperparameter γ
5. Length k of each embedding vector
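These inputs correspond to fields used throughout the code below. A minimal sketch for orientation; the concrete values are illustrative assumptions, not taken from the repository:

```java
// Illustrative hyperparameters; the actual repository values may differ.
static int vector_len = 50;          // k: length of every embedding vector
static double margin = 1.0;          // γ: margin hyperparameter
static double learning_rate = 0.001;
static boolean L1_flag = true;       // true: L1 distance, false: squared L2
static double[][] entity_vec;        // one row per entity in E
static double[][] relation_vec;      // one row per relation in L
```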

Initialization

 

A vector is initialized for every entity in the Entities set E and every relation in the Relations set L, and each initialized vector is normalized to unit L2 norm.

An entity uses the same vector regardless of whether it appears in the head or the tail position.

```java
for (int i = 0; i < relation_num; i++) {
    for (int j = 0; j < vector_len; j++) {
        relation_vec[i][j] = uniform(-6 / sqrt(vector_len), 6 / sqrt(vector_len));
    }
    // note: unlike Algorithm 1 in the paper (line 2), the relation vectors are not normalized here
}
for (int i = 0; i < entity_num; i++) {
    for (int j = 0; j < vector_len; j++) {
        entity_vec[i][j] = uniform(-6 / sqrt(vector_len), 6 / sqrt(vector_len)); // one vector per entity
    }
    norm(entity_vec[i], vector_len); // L2-normalize each entity vector
}
```
```java
static double uniform(double min, double max) {
    // generate a double in [min, max), analogous to Python's random.uniform
    return min + (max - min) * Math.random();
}
```

Updating the initialized vectors with gradient descent
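The numbered points below refer to the training pseudocode (Algorithm 1) in the TransE paper, reproduced here since the original figure is not shown:

```
input: training set S = {(h, ℓ, t)}, entity set E, relation set L, margin γ, dimension k
 1: initialize ℓ ← uniform(−6/√k, 6/√k) for each ℓ ∈ L
 2:            ℓ ← ℓ / ‖ℓ‖ for each ℓ ∈ L
 3:            e ← uniform(−6/√k, 6/√k) for each e ∈ E
 4: loop
 5:   e ← e / ‖e‖ for each e ∈ E
 6:   Sbatch ← sample(S, b)                        // sample a minibatch of size b
 7:   Tbatch ← ∅
 8:   for (h, ℓ, t) ∈ Sbatch do
 9:     (h′, ℓ, t′) ← sample(S′(h,ℓ,t))            // sample a corrupted triple
10:     Tbatch ← Tbatch ∪ {((h, ℓ, t), (h′, ℓ, t′))}
11:   end for
12:   update embeddings w.r.t. Σ ∇[γ + d(h + ℓ, t) − d(h′ + ℓ, t′)]₊
13: end loop
```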

1. Line 6: in every iteration, a minibatch of size b, called Sbatch, is sampled from the training set.

2. Lines 7 to 10: a corrupted triple is created by replacing the head or the tail; each (correct, corrupted) pair is collected into Tbatch.

3. Based on the loss, gradient updates are applied to the vectors of h, ℓ, and t, and to the corrupted h′ or t′.

Note that the update order in the code differs slightly from the pseudocode. Another difference: the code terminates after a fixed number of epochs, whereas the paper stops training based on prediction performance on a validation set.

```java
for (int epoch = 0; epoch < nepoch; epoch++) {
    res = 0; // total loss accumulated over this epoch
    for (int batch = 0; batch < nbatches; batch++) {
        for (int k = 0; k < batchsize; k++) {
            int i = rand_max(fb_h.size()); // pick the i-th training triple at random
            int j = rand_max(entity_num);  // pick a random entity j for corruption
            int relation_id = fb_r.get(i); // relation of the i-th triple
            // probability (out of 1000) of corrupting the tail rather than the head
            double pr = 1000 * right_num.get(relation_id) / (right_num.get(relation_id) + left_num.get(relation_id));
            if (method == 0) {
                pr = 500; // "unif": corrupt head or tail with equal probability
            }
            if (rand() % 1000 < pr) {
                // corrupt the tail
                Pair<Integer, Integer> key = new Pair<>(fb_h.get(i), fb_r.get(i));
                Set<Integer> values = head_relation2tail.get(key); // all tails observed for (head, relation)
                while (values.contains(j)) {
                    j = rand_max(entity_num); // resample until (head, relation, j) is genuinely false
                }
                // corrected call: j is the corrupted tail (see the note further below)
                res += train_kb(fb_h.get(i), fb_l.get(i), fb_r.get(i), fb_h.get(i), j, fb_r.get(i), res);
            } else {
                // corrupt the head
                Pair<Integer, Integer> key = new Pair<>(j, fb_r.get(i));
                Set<Integer> values = head_relation2tail.get(key);
                if (values != null) {
                    while (values.contains(fb_l.get(i))) { // (j, relation, tail) already exists: resample
                        j = rand_max(entity_num);
                        key = new Pair<>(j, fb_r.get(i));
                        values = head_relation2tail.get(key);
                        if (values == null) break;
                    }
                }
                res += train_kb(fb_h.get(i), fb_l.get(i), fb_r.get(i), j, fb_l.get(i), fb_r.get(i), res);
            }
            norm(relation_vec[fb_r.get(i)], vector_len); // re-normalize the updated vectors
            norm(entity_vec[fb_h.get(i)], vector_len);
            norm(entity_vec[fb_l.get(i)], vector_len);
            norm(entity_vec[j], vector_len);
        }
    }
    System.out.printf("epoch: %s %s\n", epoch, res);
}
```

Generating a random number

The value decides whether to corrupt the head or the tail: with method == 0 both choices are equally likely ("unif" sampling); otherwise pr is weighted by the per-relation statistics right_num and left_num (presumably the "bern" sampling strategy, matching the .bern file names).

```java
double pr = 1000 * right_num.get(relation_id) / (right_num.get(relation_id) + left_num.get(relation_id));
if (method == 0) {
    pr = 500; // "unif": equal probability for head and tail corruption
}
if (rand() % 1000 < pr) { // true: corrupt the tail; false: corrupt the head
```

Generating corrupted triples

The corrupted triple must not already exist in the original data set.
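The excerpts rely on a lookup table head_relation2tail that maps each (head, relation) pair to the set of tails observed in the data. Its construction is not part of the excerpts; a minimal sketch of how it could be built, assuming the repository's Pair class implements equals() and hashCode():

```java
// Hypothetical construction of head_relation2tail from the training triples.
Map<Pair<Integer, Integer>, Set<Integer>> head_relation2tail = new HashMap<>();
for (int t = 0; t < fb_h.size(); t++) {
    Pair<Integer, Integer> key = new Pair<>(fb_h.get(t), fb_r.get(t));
    head_relation2tail.computeIfAbsent(key, k -> new HashSet<>()).add(fb_l.get(t));
}
```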

Corrupting the tail entity

```java
if (rand() % 1000 < pr) {
    Pair<Integer, Integer> key = new Pair<>(fb_h.get(i), fb_r.get(i));
    Set<Integer> values = head_relation2tail.get(key); // all tails observed for (head, relation)
    while (values.contains(j)) {
        j = rand_max(entity_num); // resample until (head, relation, j) is a genuinely false triple
    }
    res += train_kb(fb_h.get(i), fb_l.get(i), fb_r.get(i), fb_h.get(i), j, fb_r.get(i), res);
}
```

Note that the original code file contains a bug at this point: j was passed in the wrong argument position, as the corrupted head instead of the corrupted tail. The corrected call, shown in the excerpt above, is:

```java
res += train_kb(fb_h.get(i), fb_l.get(i), fb_r.get(i), fb_h.get(i), j, fb_r.get(i), res);
```

Corrupting the head entity

```java
Pair<Integer, Integer> key = new Pair<>(j, fb_r.get(i)); // candidate corrupted head j
Set<Integer> values = head_relation2tail.get(key);
if (values != null) {
    while (values.contains(fb_l.get(i))) { // (j, relation, tail) already exists: resample
        j = rand_max(entity_num);
        key = new Pair<>(j, fb_r.get(i));
        values = head_relation2tail.get(key);
        if (values == null) break; // j never appears as a head for this relation
    }
}
```

Computing the loss

The embeddings are updated only when the loss value is positive.
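For reference, this is the margin-based ranking loss from the paper:

L = Σ_{(h,ℓ,t)∈Sbatch} [ γ + d(h + ℓ, t) − d(h′ + ℓ, t′) ]₊

where [x]₊ denotes max(0, x), so the check `sum1 + margin > sum2` below is exactly the condition that the loss term is positive.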

 

```java
static double train_kb(int head_a, int tail_a, int relation_a, int head_b, int tail_b, int relation_b, double res) {
    // margin-based ranking loss (not maximum-likelihood estimation)
    double sum1 = calc_sum(head_a, tail_a, relation_a); // distance of the correct triple
    double sum2 = calc_sum(head_b, tail_b, relation_b); // distance of the corrupted triple
    double loss = 0;
    if (sum1 + margin > sum2) { // completed from context: the original excerpt was truncated here
        loss = margin + sum1 - sum2;
        gradient(head_a, tail_a, relation_a, head_b, tail_b, relation_b);
    }
    return loss; // the caller accumulates: res += train_kb(...)
}
```

Computing the vector distance (the sum1 and sum2 above)

There are two variants: one sums absolute differences (L1), the other sums squared differences (squared L2 — note that the code squares each component; it does not take a square root).

In both cases, sum accumulates the difference in every dimension.
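In the paper's notation, taken componentwise over the k dimensions:

d(h + ℓ, t) = ‖h + ℓ − t‖₁ = Σᵢ |hᵢ + ℓᵢ − tᵢ|    (L1_flag = true)
d(h + ℓ, t) = ‖h + ℓ − t‖₂² = Σᵢ (hᵢ + ℓᵢ − tᵢ)²   (L1_flag = false)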

```java
static double calc_sum(int e1, int e2, int rel) {
    // distance between head + relation and the tail entity
    double sum = 0;
    if (L1_flag) {
        for (int i = 0; i < vector_len; i++) {
            sum += abs(entity_vec[e2][i] - entity_vec[e1][i] - relation_vec[rel][i]);
        }
    } else {
        for (int i = 0; i < vector_len; i++) {
            sum += sqr(entity_vec[e2][i] - entity_vec[e1][i] - relation_vec[rel][i]);
        }
    }
    return sum;
}
```

Updating the embeddings

A gradient-descent step is applied to the head, relation, and tail vectors of the correct triple (decreasing its distance), and a step with the opposite sign to those of the corrupted triple (increasing its distance):

```java
gradient(head_a, tail_a, relation_a, head_b, tail_b, relation_b);
```
```java
static void gradient(int head_a, int tail_a, int relation_a, int head_b, int tail_b, int relation_b) {
    for (int i = 0; i < vector_len; i++) {
        double delta1 = entity_vec[tail_a][i] - entity_vec[head_a][i] - relation_vec[relation_a][i];
        double delta2 = entity_vec[tail_b][i] - entity_vec[head_b][i] - relation_vec[relation_b][i];
        if (L1_flag) {
            // L1 distance: the gradient of |delta| is sign(delta)
            double x = (delta1 > 0) ? 1 : -1;
            relation_vec[relation_a][i] += x * learning_rate; // pull the correct triple together
            entity_vec[head_a][i] += x * learning_rate;
            entity_vec[tail_a][i] -= x * learning_rate;
            x = (delta2 > 0) ? 1 : -1;
            relation_vec[relation_b][i] -= x * learning_rate; // push the corrupted triple apart
            entity_vec[head_b][i] -= x * learning_rate;
            entity_vec[tail_b][i] += x * learning_rate;
        } else {
            // squared-L2 distance: the gradient of delta^2 is 2 * delta;
            // delta must keep its sign here (the original code wrongly took abs())
            relation_vec[relation_a][i] += learning_rate * 2 * delta1;
            entity_vec[head_a][i] += learning_rate * 2 * delta1;
            entity_vec[tail_a][i] -= learning_rate * 2 * delta1;
            relation_vec[relation_b][i] -= learning_rate * 2 * delta2;
            entity_vec[head_b][i] -= learning_rate * 2 * delta2;
            entity_vec[tail_b][i] += learning_rate * 2 * delta2;
        }
    }
}
```
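To see where the update signs come from: for one dimension of the correct triple the squared-L2 loss term is δ² with δ = t − h − r, so ∂δ²/∂h = ∂δ²/∂r = −2δ and ∂δ²/∂t = +2δ. A gradient-descent step therefore adds 2 · learning_rate · δ to h and r and subtracts it from t, which is exactly the first three updates in the else branch. The corrupted triple enters the ranking loss with a negative sign, so its updates are reversed; in the L1 case, 2δ is replaced by sign(δ), giving the ±1 factor x.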

After the update, the affected vectors are normalized again.

```java
norm(relation_vec[fb_r.get(i)], vector_len); // re-normalize the relation vector
norm(entity_vec[fb_h.get(i)], vector_len);   // ...the head vector
norm(entity_vec[fb_l.get(i)], vector_len);   // ...the tail vector
norm(entity_vec[j], vector_len);             // ...the corrupted entity's vector
```
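The norm() helper is not shown in any of the excerpts. A minimal sketch, assuming it performs plain L2 normalization (the paper constrains entity embeddings to unit length; the repository may instead renormalize only when the norm exceeds 1):

```java
// Hypothetical implementation: scale vec to unit L2 length.
static void norm(double[] vec, int len) {
    double s = 0;
    for (int i = 0; i < len; i++) {
        s += vec[i] * vec[i];
    }
    s = Math.sqrt(s);
    if (s > 0) { // avoid division by zero for an all-zero vector
        for (int i = 0; i < len; i++) {
            vec[i] /= s;
        }
    }
}
```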

Evaluating the predictions

The evaluation reports MeanRank (the average rank of the correct answer among all candidates; lower is better) and Hits@k (the fraction of test triples whose correct answer ranks within the top k; higher is better), each in a raw and a filtered variant.

```java
public void run() throws IOException {
    relation_vec = new double[relation_num][vector_len];
    entity_vec = new double[entity_num][vector_len];
    Read_Vec_File("resource/result/relation2vec.bern", relation_vec);
    Read_Vec_File("resource/result/entity2vec.bern", entity_vec);
    // accumulated ranks of the correct triples (numerators of MeanRank)
    int head_meanRank_raw = 0, tail_meanRank_raw = 0, head_meanRank_filter = 0, tail_meanRank_filter = 0;
    // counts of correct triples ranked within the top 10 (Hits@10)
    int head_hits10 = 0, tail_hits10 = 0, head_hits10_filter = 0, tail_hits10_filter = 0;
    int relation_meanRank_raw = 0, relation_meanRank_filter = 0;
    int relation_hits10 = 0, relation_hits10_filter = 0; // despite the name, these count Hits@5 below
    // ------------------------ evaluation: link prediction ----------------------------------------
    System.out.printf("Total test triple = %s\n", fb_l.size());
    System.out.printf("The evaluation of link predict\n");
    for (int id = 0; id < fb_l.size(); id++) {
        int head = fb_h.get(id);
        int tail = fb_l.get(id);
        int relation = fb_r.get(id);
        List<Pair<Integer, Double>> head_dist = new ArrayList<>(); // head prediction
        for (int i = 0; i < entity_num; i++) {
            double sum = calc_sum(i, tail, relation); // distance of every candidate head
            head_dist.add(new Pair<>(i, sum));
        }
        Collections.sort(head_dist, (o1, o2) -> Double.compare(o1.b, o2.b)); // sort candidates by distance
        int filter = 0; // number of true triples seen so far during the scan
        for (int i = 0; i < head_dist.size(); i++) {
            int cur_head = head_dist.get(i).a;
            if (hrt_isvalid(cur_head, relation, tail)) { // the candidate forms a true triple: record it
                filter += 1;
            }
            if (cur_head == head) {
                head_meanRank_raw += i; // i = number of candidates ranked above the correct head
                head_meanRank_filter += i - filter;
                if (i <= 10) {
                    head_hits10++; // raw hit
                }
                if (i - filter <= 10) { // drop candidates that exist in the data set but are not this answer
                    head_hits10_filter++; // filtered hit
                }
                break;
            }
        }
        filter = 0;
        List<Pair<Integer, Double>> tail_dist = new ArrayList<>(); // tail prediction
        for (int i = 0; i < entity_num; i++) {
            double sum = calc_sum(head, i, relation);
            tail_dist.add(new Pair<>(i, sum));
        }
        Collections.sort(tail_dist, (o1, o2) -> Double.compare(o1.b, o2.b));
        for (int i = 0; i < tail_dist.size(); i++) {
            int cur_tail = tail_dist.get(i).a;
            if (hrt_isvalid(head, relation, cur_tail)) {
                filter++;
            }
            if (cur_tail == tail) {
                tail_meanRank_raw += i;
                tail_meanRank_filter += i - filter;
                if (i <= 10) {
                    tail_hits10++;
                }
                if (i - filter <= 10) {
                    tail_hits10_filter++;
                }
                break;
            }
        }
    }
    System.out.printf("-----head prediction------\n");
    System.out.printf("Raw MeanRank: %.3f, Filter MeanRank: %.3f\n",
            (head_meanRank_raw * 1.0) / fb_l.size(), (head_meanRank_filter * 1.0) / fb_l.size());
    System.out.printf("Raw Hits@10: %.3f, Filter Hits@10: %.3f\n",
            (head_hits10 * 1.0) / fb_l.size(), (head_hits10_filter * 1.0) / fb_l.size());
    System.out.printf("-----tail prediction------\n");
    System.out.printf("Raw MeanRank: %.3f, Filter MeanRank: %.3f\n",
            (tail_meanRank_raw * 1.0) / fb_l.size(), (tail_meanRank_filter * 1.0) / fb_l.size());
    System.out.printf("Raw Hits@10: %.3f, Filter Hits@10: %.3f\n",
            (tail_hits10 * 1.0) / fb_l.size(), (tail_hits10_filter * 1.0) / fb_l.size());
    // ------------------------ evaluation: relation prediction ----------------------------------------
    int relation_hits = 5; // use Hits@5 as the metric for relations
    for (int id = 0; id < fb_l.size(); id++) {
        int head = fb_h.get(id);
        int tail = fb_l.get(id);
        int relation = fb_r.get(id);
        List<Pair<Integer, Double>> relation_dist = new ArrayList<>();
        for (int i = 0; i < relation_num; i++) {
            double sum = calc_sum(head, tail, i);
            relation_dist.add(new Pair<>(i, sum));
        }
        Collections.sort(relation_dist, (o1, o2) -> Double.compare(o1.b, o2.b));
        int filter = 0; // number of true triples seen so far during the scan
        for (int i = 0; i < relation_dist.size(); i++) {
            int cur_relation = relation_dist.get(i).a;
            if (hrt_isvalid(head, cur_relation, tail)) { // the candidate forms a true triple: record it
                filter += 1;
            }
            if (cur_relation == relation) {
                relation_meanRank_raw += i; // i = number of candidates ranked above the correct relation
                relation_meanRank_filter += i - filter;
                if (i <= 5) {
                    relation_hits10++;
                }
                if (i - filter <= 5) {
                    relation_hits10_filter++;
                }
                break;
            }
        }
    }
    System.out.printf("-----relation prediction------\n");
    System.out.printf("Raw MeanRank: %.3f, Filter MeanRank: %.3f\n",
            (relation_meanRank_raw * 1.0) / fb_r.size(), (relation_meanRank_filter * 1.0) / fb_r.size());
    System.out.printf("Raw Hits@%d: %.3f, Filter Hits@%d: %.3f\n",
            relation_hits, (relation_hits10 * 1.0) / fb_r.size(),
            relation_hits, (relation_hits10_filter * 1.0) / fb_r.size());
}
```

First, the trained vector files are read: the relation embeddings and the entity embeddings.

```java
Read_Vec_File("resource/result/relation2vec.bern", relation_vec);
Read_Vec_File("resource/result/entity2vec.bern", entity_vec);
```
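Read_Vec_File itself is not shown. A minimal sketch, assuming each line of the file stores the whitespace-separated components of one vector (the repository's actual file format may differ):

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

// Hypothetical reader: one embedding vector per line.
static void Read_Vec_File(String path, double[][] vec) throws IOException {
    try (BufferedReader br = new BufferedReader(new FileReader(path))) {
        for (int i = 0; i < vec.length; i++) {
            String[] parts = br.readLine().trim().split("\\s+");
            for (int j = 0; j < vec[i].length; j++) {
                vec[i][j] = Double.parseDouble(parts[j]);
            }
        }
    }
}
```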

Taking head prediction as an example

Each triple in the test set is one test case:

1. Compute the distance (score) of every candidate head.

2. Rank the candidates by distance, best (smallest) first.

3. Collect both the raw and the filtered statistics.

Filtered results: candidates that themselves form a true triple in the data set, i.e. they satisfy the (?, ℓ, t) pattern but differ from this test case's head, are removed from the ranking before the rank is computed.
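For example, if the correct head is ranked 15th in the raw list and 6 of the 14 candidates ahead of it themselves form true triples, its filtered rank becomes 15 − 6 = 9: a filtered Hits@10 but not a raw one. (The numbers are illustrative, not taken from FB15k.)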

```java
// ------------------------ evaluation: link prediction ----------------------------------------
System.out.printf("Total test triple = %s\n", fb_l.size());
System.out.printf("The evaluation of link predict\n");
for (int id = 0; id < fb_l.size(); id++) {
    int head = fb_h.get(id);
    int tail = fb_l.get(id);
    int relation = fb_r.get(id);
    List<Pair<Integer, Double>> head_dist = new ArrayList<>(); // head prediction
    for (int i = 0; i < entity_num; i++) {
        double sum = calc_sum(i, tail, relation); // distance of every candidate head
        head_dist.add(new Pair<>(i, sum));
    }
    Collections.sort(head_dist, (o1, o2) -> Double.compare(o1.b, o2.b)); // sort candidates by distance
    int filter = 0; // number of true triples seen so far during the scan
    for (int i = 0; i < head_dist.size(); i++) {
        int cur_head = head_dist.get(i).a;
        if (hrt_isvalid(cur_head, relation, tail)) { // the candidate forms a true triple: record it
            filter += 1;
        }
        if (cur_head == head) {
            head_meanRank_raw += i; // i = number of candidates ranked above the correct head
            head_meanRank_filter += i - filter;
            if (i <= 10) {
                head_hits10++; // raw hit
            }
            if (i - filter <= 10) { // drop candidates that exist in the data set but are not this answer
                head_hits10_filter++; // filtered hit
            }
            break;
        }
    }
```