当前位置:   article > 正文

TransE,知识图谱嵌入(KGE)源码阅读(一)_transe,知识图谱嵌入(kge)论文精读

transe,知识图谱嵌入(kge)论文精读

TransE,知识图谱嵌入(KGE)源码阅读(一)


Paper: Antoine Bordes等人在2013年发表于NIPS上的文章

Paper Understanding:TransE,知识图谱嵌入(KGE)论文精读

Algorithm:
在这里插入图片描述
OpenKE:Ubuntu 20.04子系统中使用OpenKE进行复现,详情点击此处进入查看


TransE 算法训练代码和阅读笔记

理解时参考了博客:知识表示学习 TransE 代码逻辑梳理 超详细解析,但加入了自己更详细的注解,可以互为参照

import codecs
import numpy as np
import copy
import time
import random


"""
定义两个字典,用于查找实体和关系名对应的ID,数据格式:关系名/实体名:ID
"""
entities2id = {}
relations2id = {}


def dataloader(file1, file2, file3):
    """
    加载数据集FB15k中的三个文件 或 WN18中的文件
    :param file1: 三元组,文件的每条记录是:头实体名字,关系名字,尾实体名字
    :param file2: 实体集,文件的每条记录是:实体名字,实体ID
    :param file3: 关系集,文件的每条记录是:关系名字,关系ID
    :return:
    """
    print("load file...")
    entity = []
    relation = []
    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            entities2id[line[0]] = line[1]
            entity.append(line[1])

        """
        共lines2行,进行逐行扫描读取,直到扫描完成,结束循环
        """
        for line in lines2:
            line = line.strip().split('\t')
            """
            不等于2,说明扫描到最后一行,结束本轮,进行下一轮,line++
            """
            if len(line) != 2:
                continue
            """
            字典中存放的数据形式是:关系名:关系ID
            entity中存放关系ID
            """
            relations2id[line[0]] = line[1]
            relation.append(line[1])

    triple_list = []

    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            # 去掉首尾空格,以空格作为分隔符,将分割后元素放入数组中
            triple = line.strip().split("\t")
            """
            不等于3,说明扫描到最后一行,结束本轮,进行下一轮,line++
            """
            if len(triple) != 3:
                continue
            """
            从字典中找到三元组中实体和关系名对应的id
            """
            h_ = entities2id[triple[0]]
            r_ = relations2id[triple[1]]
            t_ = entities2id[triple[2]]
            """
            列表triple_list中存放的也都是各个三元组,但是它们的存在是以ID的形式
            """
            triple_list.append([h_, r_, t_])

    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity), len(relation), len(triple_list)))

    """
    返回三个列表,里边存放的都是id
    """
    return entity, relation, triple_list



"""
定义L1范数和L2范数,用于打分函数的规范
"""
def norm_l1(h, r, t):
    return np.sum(np.fabs(h + r - t))


def norm_l2(h, r, t):
    return np.sum(np.square(h + r - t))


class TransE:
    def __init__(self, entity, relation, triple_list, embedding_dim=50, lr=0.01, margin=1.0, norm=1):
        self.entities = entity
        self.relations = relation
        self.triples = triple_list
        self.dimension = embedding_dim
        self.learning_rate = lr
        self.margin = margin
        self.norm = norm
        self.loss = 0.0

    def data_initialize(self):
        """
        初始化向量:构建字典集合,来存放实体向量和关系向量
        :return:
        """

        entityVectorList = {}
        relationVectorList = {}


        # 对每个实体生成一个 dimension 维的向量,这个向量由一个列表表示
        # 列表中的每一个元素在(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension),self.dimension)之间

        for entity in self.entities:
            entity_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension),
                                              self.dimension)

            # 对每个实体随机生成的向量,赋予给各个实体,当然这里不用实体名,而是向量与实体的ID对应,放入字典中
            entityVectorList[entity] = entity_vector

        for relation in self.relations:
            relation_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension),
                                                self.dimension)
            relation_vector = self.normalization(relation_vector)
            relationVectorList[relation] = relation_vector

        """
        最后,self.entities,经向量初始化后由列表entities变成了字典entityVectorList,格式:{实体名:对应向量,...}
        """
        self.entities = entityVectorList
        """
        最后,self.relations,经向量初始化后由列表relations变成了字典relationVectorList,格式:{关系名:对应向量,...}
        """
        self.relations = relationVectorList

    """
    np.linalg.norm(vector),没有指定,则默认是二范数,即对矩阵或向量每个元素平方和开平方,这里是向量的模
    对向量生成作归一化处理,向量除以一个系数,这个系数是"每个元素除以元素总和的平方和的开平方"
    这样,每个向量归一化为单位向量
    """

    def normalization(self, vector):
        return vector / np.linalg.norm(vector)

    def training_run(self, epochs=1, nbatches=100, out_file_title=''):
        """
        回顾:
        当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个 epoch,然而,当一个 epoch 对于计算机而言太庞大的时候,就需要把它分成多个小块,即batch(数据集在不能将数据一次性通过神经网络的时候,就需要将数据集分成几个batch),训练集共141442个三元组,nbatches,即batch的个数为100,那么batch_size为1414,
        为什么要使用多于一个 epoch?
        因为在神经网络中传递完整的数据集一次是不够的,而且我们需要将完整的数据集在同样的神经网络中传递多次。
        但是,我们使用的是有限的数据集,并且我们使用一个迭代过程即梯度下降,随着 epoch 数量增加,神经网络中的权重的更新次数也增加,曲线从欠拟合变得过拟合
        迭代指的是完成一个epoch的batch的个数
        :param epochs:数据集共投放几次
        :param nbatches:一个epoch分成100轮投放,nbatches=100
        :param out_file_title:用于拼接成文件名的字符串
        :return:
        """
        batch_size = int(len(self.triples) / nbatches)
        print("batch size: ", batch_size)

        # range(1)即产生[0],循环一次,epoch等于0
        # 对应算法中的loop部分
        for epoch in range(epochs):
            start = time.time()
            self.loss = 0.0
            # entities是字典,其键组成的列表,是实体名列表,取到对应向量,然后norm化后,再放入原元组entities
            for entity in self.entities.keys():
                self.entities[entity] = self.normalization(self.entities[entity]);

            # 对nbatches轮训练的每一轮,都随机采样batch_size大小的三元组集合sample(S,b)作为Sbatch
            for batch in range(nbatches):
                batch_samples = random.sample(self.triples, batch_size)
                # 初始化三元组集合Tbatch为空列表,两个成对的三元组构成一个元组,作为列表中的一个元素。存放一个epoch中所有的成对构造
                Tbatch = []
                for sample in batch_samples:
                    # 从sbatch词典中采集负样本三元组,来构造负样本三元组(h',r',t'),深拷贝,这样可以在改变的基础上不影响batch_samples
                    # 拷贝后的corrupted_sample仍然是列表,其中的每个元素也是列表【h,r,t】
                    corrupted_sample = copy.deepcopy(sample)
                    # 随机改变头实体或尾实体(只能有其中一个)
                    pr = np.random.random(1)[0]
                    if pr > 0.5:
                        # 替换头实体
                        # 在原有的实体字典中所有键(实体名)构成的列表中,随机选取1个,得到一个实体名构成的列表,取出来
                        # 赋值给这个三元组的头实体,由此改变三元组,构建负样本三元组
                        corrupted_sample[0] = random.sample(self.entities.keys(), 1)[0]
                        # 因为替换后,三元组仍然有可能是正样本三元组,所以这里是过滤的过程,论文中提到的filtered
                        # 但我认为这里有些敷衍,只是避免了替换后的三元组不和原来的三元组相同,但并不能保证此三元组是否为正例,比如甲和乙都出生在南阳
                        # 而且即使再重新替换,仍有可能还随机到刚刚那个,和放回抽样是一样的
                        while corrupted_sample[0] == sample[0]:
                            corrupted_sample[0] = random.sample(self.entities.keys(), 1)[0]
                    else:
                        # 替换尾实体,这个过程与上述是一样的
                        corrupted_sample[2] = random.sample(self.entities.keys(), 1)[0]
                        while corrupted_sample[2] == sample[2]:
                            corrupted_sample[2] = random.sample(self.entities.keys(), 1)[0]

                    # 如果这个三元组((h,r,t),(h’,r’,t’))不在Tbatch中,就将其加入
                    if (sample, corrupted_sample) not in Tbatch:
                        Tbatch.append((sample, corrupted_sample))
                # 向下调用函数更新,计算损失函数,这就是整个算法过程,一次更新就是一个BATCH中所有embedding的值,然后继续下个batch的训练
                self.update_triple_embedding(Tbatch)
            end = time.time()
            # 一个epoch(100个batch投送,每个batch块大小为1414)结束,直到141400个三元组的负例全部构造完成,输出时间
            print("epoch: ", epochs)
            print("cost time: %.2f" % (end - start))
            print("running loss: ", self.loss)


        # 向着损失函数最小化的方向,用梯度下降法调整更新参数。并将结果写入存储在WN18_TransE_entity_50dim_batch1414和WN18_TransE_relation_50dim_batch1414中
        # 将entities写入新建的文件WN18_TransE_entity_50dim_batch1414,其中写入的数据格式(每一行): 实体名  对应向量
        with codecs.open(out_file_title + "TransE_entity_" + str(self.dimension) + "dim_batch" + str(batch_size),
                         "w") as f1:

            for e in self.entities.keys():
                f1.write(e + "\t")
                f1.write(str(list(self.entities[e])))
                f1.write("\n")

        with codecs.open(out_file_title + "TransE_relation_" + str(self.dimension) + "dim_batch" + str(batch_size),
                         "w") as f2:
            for r in self.relations.keys():
                f2.write(r + "\t")
                f2.write(str(list(self.relations[r])))
                f2.write("\n")

    def update_triple_embedding(self, Tbatch):
        # deepcopy 可以保证,即使list嵌套list也能让各层的地址不同, 即这里copy_entity和entities中所有的elements都不同

        # 实体名 对应向量
        copy_entity = copy.deepcopy(self.entities)
        # 关系名 对应向量
        copy_relation = copy.deepcopy(self.relations)

        for correct_sample, corrupted_sample in Tbatch:

            # 拷贝的目的是为了更新用,更新后重新赋给entity,这样当前的(下边的)能计算本轮的损失函数,决定更新的方向
            correct_copy_head = copy_entity[correct_sample[0]]
            correct_copy_tail = copy_entity[correct_sample[2]]
            relation_copy = copy_relation[correct_sample[1]]
            corrupted_copy_head = copy_entity[corrupted_sample[0]]
            corrupted_copy_tail = copy_entity[corrupted_sample[2]]

            # 取到正例的头实体对应的向量,因为triples词典构造出来Tbatch中三元组的数据格式都是实体名,关系名,我们要找对应的头尾实体和关系向量做计算
            correct_head = self.entities[correct_sample[0]]
            correct_tail = self.entities[correct_sample[2]]
            relation = self.relations[correct_sample[1]]
            corrupted_head = self.entities[corrupted_sample[0]]
            corrupted_tail = self.entities[corrupted_sample[2]]

            # 计算打分函数
            if self.norm == 1:
                correct_distance = norm_l1(correct_head, relation, correct_tail)
                corrupted_distance = norm_l1(corrupted_head, relation, corrupted_tail)

            else:
                correct_distance = norm_l2(correct_head, relation, correct_tail)
                corrupted_distance = norm_l2(corrupted_head, relation, corrupted_tail)

            # 计算损失函数
            # 如果正例特别小,负例特别大,分类器很好分,这时,loss几乎为0或小于0,这时候取0,不做叠加
            loss = self.margin + correct_distance - corrupted_distance
            # 否则,计算每个batch的损失并做叠加
            if loss > 0:
                self.loss += loss

                # 损失函数对head求梯度的绝对值就可以(默认二范数)
                correct_gradient = 2 * (correct_head + relation - correct_tail)
                corrupted_gradient = 2 * (corrupted_head + relation - corrupted_tail)
                # 如果传入的是一范数,梯度就是如下
                if self.norm == 1:
                    for i in range(len(correct_gradient)):
                        if correct_gradient[i] > 0:
                            correct_gradient[i] = 1
                        else:
                            correct_gradient[i] = -1

                        if corrupted_gradient[i] > 0:
                            corrupted_gradient[i] = 1
                        else:
                            corrupted_gradient[i] = -1

                # 更新正例的头尾实体向量
                correct_copy_head -= self.learning_rate * correct_gradient
                relation_copy -= self.learning_rate * correct_gradient
                # 尾实体向量更新,梯度本身是负数
                correct_copy_tail -= -1 * self.learning_rate * correct_gradient

                relation_copy -= -1 * self.learning_rate * corrupted_gradient
                if correct_sample[0] == corrupted_sample[0]:
                    # 如果当时随机替换的是尾实体(因为头实体相同),负例尾实体向量要再次进行更新。同时,正例头实体向量需要再更新一次,即重叠的实体更新两次,否则就会导致后一次更新覆盖前一次
                    # 更新使得正例的打分函数越来越小,负例越来越大,这样损失越接近0,能达到相对好的效果,梯度统一按负例的梯度下降
                    correct_copy_head -= -1 * self.learning_rate * corrupted_gradient
                    corrupted_copy_tail -= self.learning_rate * corrupted_gradient
                elif correct_sample[2] == corrupted_sample[2]:
                    # 如果当时随机替换的是头实体(因为尾实体相同),负例头实体向量要再次进行更新。同时,正例尾实体向量需要再更新一次,即重叠的实体再更新一次,否则就会导致后一次更新覆盖前一次
                    corrupted_copy_head -= -1 * self.learning_rate * corrupted_gradient
                    correct_copy_tail -= self.learning_rate * corrupted_gradient

                # 对正例的头尾实体向量进行归一化处理
                copy_entity[correct_sample[0]] = self.normalization(correct_copy_head)
                copy_entity[correct_sample[2]] = self.normalization(correct_copy_tail)

                # 对负例头尾实体进行归一化,还要对应上述的情况
                if correct_sample[0] == corrupted_sample[0]:
                    # 如果当时替换的头实体,归一化负例的尾实体向量
                    copy_entity[corrupted_sample[2]] = self.normalization(corrupted_copy_tail)
                elif correct_sample[2] == corrupted_sample[2]:
                    # 如果当时替换的尾实体,归一化负例的头实体向量
                    copy_entity[corrupted_sample[0]] = self.normalization(corrupted_copy_head)
                # the paper mention that the relation's embedding don't need to be normalised
                copy_relation[correct_sample[1]] = relation_copy
                # copy_relation[correct_sample[1]] = self.normalization(relation_copy)

        # 这个更新就是一个替换的过程,拷贝一份给他,他更新后还给你,这样也不影响自身求当前的损失函数和梯度
        # 这样更新后,又可以重新进行负采样,进行新一个batch的训练了
        self.entities = copy_entity
        self.relations = copy_relation


if __name__ == '__main__':
    # file1 = "FB15k\\train.txt"
    # file2 = "FB15k\\entity2id.txt"
    # file3 = "FB15k\\relation2id.txt"

    file1 = "WN18\\wordnet-mlj12-train.txt"
    file2 = "WN18\\entity2id.txt"
    file3 = "WN18\\relation2id.txt"

    
    # 读取三个文件中的数据,返回三元组集、实体集和关系集
    entity_set, relation_set, triple_list = dataloader(file1, file2, file3)
    # 实例化TransE类对象,传递参数
    transE = TransE(entity_set, relation_set, triple_list, embedding_dim=50, lr=0.01, margin=1.0, norm=2)
    # 数据初始化
    transE.data_initialize()
    # 训练,传递参数WN18_,用于写入文件时,拼接形成文件名
    transE.training_run(out_file_title="WN18_")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345

欢迎评论区纠错和讨论!希望自己走的弯路,能让大家避开,更多疑问联系我,QQ:743337163

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/712682
推荐阅读
相关标签
  

闽ICP备14008679号