当前位置:   article > 正文

transe 简单代码实现_transe实现

transe实现

本文代码基于 TransE 算法,对知识图谱中的实体和关系进行训练,得到它们的向量表示。
训练结果写入两个文本文件:entityVector10.txt 和 relationVector10.txt(与下方代码中的输出文件名一致)。
数据集无法在此上传,如有需要请联系我。

# -*- coding: utf-8 -*-
"""
@description: 增加了对代码的一些注解
"""
from random import uniform, sample
from numpy import *
from copy import deepcopy

class TransE:
    """Train TransE embeddings for the entities and relations of a knowledge graph.

    Triples are handled in ``(head, tail, relation)`` order; that is the order
    in which ``getCorruptedTriplet`` and ``update`` index them.

    Typical use: construct the object, call ``initialize()`` once to turn the
    entity/relation name lists into ``{name: vector}`` dictionaries, then call
    ``transE()`` to train.  The class relies on the module-level helpers
    ``init``, ``norm``, ``distanceL1`` and ``distanceL2``.
    """

    def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):
        """Store the training data and hyper-parameters.

        :param entityList: iterable of entity names; replaced by a
            ``{name: vector}`` dict once ``initialize()`` has run
        :param relationList: iterable of relation names; replaced by a
            ``{name: vector}`` dict once ``initialize()`` has run
        :param tripleList: sequence of ``(head, tail, relation)`` triples
        :param margin: constant added to the margin ranking loss
        :param learingRate: step size of each update (the misspelt name is
            kept because callers may pass it by keyword)
        :param dim: number of components of every embedding vector
        :param L1: use the L1 distance when true, otherwise the squared L2
            distance
        """
        self.margin = margin
        self.learingRate = learingRate
        self.dim = dim
        self.entityList = entityList
        self.relationList = relationList
        self.tripleList = tripleList
        self.loss = 0  # loss accumulated since the last checkpoint
        self.L1 = L1

    def initialize(self):
        """Replace the entity/relation name lists by ``{name: vector}`` dicts.

        Each vector has ``self.dim`` components drawn with ``init`` and is
        then rescaled with ``norm``.
        """
        entityVectorList = {
            entity: norm([init(self.dim) for _ in range(self.dim)])
            for entity in self.entityList
        }
        print("entityVector初始化完成,数量是%d"%len(entityVectorList))
        relationVectorList = {
            relation: norm([init(self.dim) for _ in range(self.dim)])
            for relation in self.relationList
        }
        print("relationVectorList初始化完成,数量是%d"%len(relationVectorList))
        self.entityList = entityVectorList
        self.relationList = relationVectorList

    def transE(self, cI = 20, batchSize = 150, checkpointEvery = 100,
               relationVectorPath = r"F:\pycharm的项目\transe\data\FB15k\relationVector10.txt",
               entityVectorPath = r"F:\pycharm的项目\transe\data\FB15k\entityVector10.txt"):
        """Run ``cI`` training cycles, checkpointing the vectors periodically.

        :param cI: number of training cycles (one mini-batch each)
        :param batchSize: number of triples sampled per cycle
        :param checkpointEvery: every this many cycles (cycle 0 included) the
            accumulated loss is printed, both vector files are rewritten and
            the loss counter is reset
        :param relationVectorPath: checkpoint file for the relation vectors
        :param entityVectorPath: checkpoint file for the entity vectors

        The path defaults are the locations that used to be hard-coded in
        this method; pass your own paths when running on another machine.
        """
        print("训练开始")
        # The set of entity names does not change while training, so the
        # candidate list for negative sampling is built once.
        entityPool = list(self.entityList)
        for cycleIndex in range(cI):
            Sbatch = self.getSample(batchSize)
            Tbatch = []  # distinct (original triple, corrupted triple) pairs
            for sbatch in Sbatch:
                tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch, entityPool))
                if tripletWithCorruptedTriplet not in Tbatch:
                    Tbatch.append(tripletWithCorruptedTriplet)
            self.update(Tbatch)
            if cycleIndex % checkpointEvery == 0:
                print("第%d次循环"%cycleIndex)
                print(self.loss)
                self.writeRelationVector(relationVectorPath)
                self.writeEntilyVector(entityVectorPath)
                self.loss = 0

    def getSample(self, size):
        """Return up to ``size`` triples drawn without replacement.

        The request is clamped to the number of available triples so that a
        small training set does not make ``sample`` raise ``ValueError``.
        """
        available = len(self.tripleList)
        if size > available:
            size = available
        return sample(self.tripleList, size)

    def getCorruptedTriplet(self, triplet, entityPool = None):
        """Return a copy of ``triplet`` with its head or its tail replaced.

        Exactly one of the two entities is swapped for a randomly drawn,
        different entity; the relation is kept.

        :param triplet: ``(head, tail, relation)``
        :param entityPool: optional sequence of candidate entity names; when
            omitted it is built from ``self.entityList``
        :return: the corrupted ``(head, tail, relation)`` tuple
        :raises ValueError: if fewer than two candidate entities exist, since
            no replacement different from the original could be drawn
        """
        if entityPool is None:
            # random.sample needs a real sequence; dict views are rejected
            # by Python 3.11+.
            entityPool = list(self.entityList)
        if len(entityPool) < 2:
            raise ValueError("at least two entities are required to corrupt a triplet")

        # Negative draw: corrupt the head (index 0); otherwise the tail (index 1).
        position = 0 if uniform(-1, 1) < 0 else 1
        while True:
            entityTemp = sample(entityPool, 1)[0]
            if entityTemp != triplet[position]:
                break
        if position == 0:
            return (entityTemp, triplet[1], triplet[2])
        return (triplet[0], entityTemp, triplet[2])

    def update(self, Tbatch):
        """Apply one mini-batch of margin-loss updates.

        For every ``(triple, corruptedTriple)`` pair the loss term
        ``margin + d(triple) - d(corruptedTriple)`` is evaluated on the
        vectors as they were before the batch started.  Pairs with a positive
        term add it to ``self.loss`` and move the vectors involved; every
        moved vector is re-normalised with ``norm``.

        :param Tbatch: iterable of ``((h, t, r), (h', t', r))`` pairs
        """
        # Shallow copies are sufficient: stored vectors are never modified in
        # place below, changed entries are rebound to new arrays.
        newEntityList = dict(self.entityList)
        newRelationList = dict(self.relationList)
        distance = distanceL1 if self.L1 else distanceL2

        for triplet, corruptedTriplet in Tbatch:
            headKey, tailKey, relationKey = triplet[0], triplet[1], triplet[2]
            corruptedHeadKey, corruptedTailKey = corruptedTriplet[0], corruptedTriplet[1]

            # Distances and gradients use the pre-batch vectors.
            head = self.entityList[headKey]
            tail = self.entityList[tailKey]
            relation = self.relationList[relationKey]
            corruptedHead = self.entityList[corruptedHeadKey]
            corruptedTail = self.entityList[corruptedTailKey]

            # margin loss = max(0, margin + pos - neg)
            eg = self.margin + distance(head, tail, relation) - distance(corruptedHead, corruptedTail, relation)
            if not eg > 0:
                continue
            self.loss += eg

            if self.L1:
                # The gradient of the L1 distance is the sign of the residual;
                # scale it by the learning rate so the step size is honoured.
                tempPositive = self.learingRate * sign(tail - head - relation)
                tempNegtative = self.learingRate * sign(corruptedTail - corruptedHead - relation)
            else:
                tempPositive = 2 * self.learingRate * (tail - head - relation)
                tempNegtative = 2 * self.learingRate * (corruptedTail - corruptedHead - relation)

            # The corrupted triple shares one entity with the original one, so
            # the steps are summed per entity before being applied; otherwise
            # the second write would discard the first.
            entitySteps = {}
            for key, step in ((headKey, tempPositive),
                              (tailKey, -tempPositive),
                              (corruptedHeadKey, -tempNegtative),
                              (corruptedTailKey, tempNegtative)):
                if key in entitySteps:
                    entitySteps[key] = entitySteps[key] + step
                else:
                    entitySteps[key] = step

            # Only the vectors touched by this pair are re-normalised.
            for key, step in entitySteps.items():
                newEntityList[key] = norm(newEntityList[key] + step)
            newRelationList[relationKey] = norm(newRelationList[relationKey] + tempPositive - tempNegtative)

        self.entityList = newEntityList
        self.relationList = newRelationList

    def writeEntilyVector(self, dir):
        """Write one line per entity to ``dir``: the name, a tab, the vector as a list.

        :param dir: path of the output file (overwritten)
        """
        print("写入实体")
        with open(dir, 'w') as entityVectorFile:
            for entity, vector in self.entityList.items():
                entityVectorFile.write(entity + "\t" + str(vector.tolist()) + "\n")

    def writeRelationVector(self, dir):
        """Write one line per relation to ``dir``: the name, a tab, the vector as a list.

        :param dir: path of the output file (overwritten)
        """
        print("写入关系")
        with open(dir, 'w') as relationVectorFile:
            for relation, vector in self.relationList.items():
                relationVectorFile.write(relation + "\t" + str(vector.tolist()) + "\n")

def init(dim):
    """Draw one random starting value for a component of a ``dim``-dimensional vector.

    The value is taken uniformly from the interval bounded by
    ``-6 / sqrt(dim)`` and ``6 / sqrt(dim)``.
    """
    bound = 6 / (dim ** 0.5)
    return uniform(-bound, bound)

def distanceL1(h, t ,r):
    """Return the L1 distance between ``h + r`` and ``t``.

    :param h: head embedding
    :param t: tail embedding
    :param r: relation embedding
    :return: sum of the absolute values of the components of ``h + r - t``
    """
    residual = h + r - t
    return fabs(residual).sum()

def distanceL2(h, t, r):
    """Return the squared L2 distance between ``h + r`` and ``t``.

    :param h: head embedding
    :param t: tail embedding
    :param r: relation embedding
    :return: sum of the squared components of ``h + r - t`` (no square root
        is taken)
    """
    residual = h + r - t
    squared = residual * residual
    return squared.sum()
 
def norm(list):
    """Return ``list`` rescaled to unit Euclidean length as a new float array.

    The argument itself is left untouched.  Integer input is converted to
    float first, so the scaled components are not truncated.  A vector of
    length zero cannot be rescaled and is returned as an all-zero float array
    instead of producing NaN values.

    :param list: vector given as a sequence of numbers or a numpy array (the
        parameter name is kept for backward compatibility although it shadows
        the builtin)
    :return: numpy float array with the same number of components
    """
    vector = array(list, dtype=float)  # always a copy, never a view of the input
    length = linalg.norm(vector)
    if length == 0:
        return vector
    return vector / length

def openDetailsAndId(dir,sp="\t"):
    """Read the first column of a ``name<sp>id`` file.

    Each line is stripped of surrounding whitespace and split on ``sp``; the
    first field is collected as a name.  Lines that are empty after stripping
    are skipped, so a blank or trailing empty line no longer yields an
    empty-string name.

    :param dir: path of the file to read, e.g. lines such as ``/m/06rf7  0``
    :param sp: field separator
    :return: ``(count, names)`` where ``count`` is ``len(names)``
    """
    names = []
    with open(dir) as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            names.append(line.split(sp)[0])
    return len(names), names

def openTrain(dir,sp="\t"):
    """Read the training triples from a delimited text file.

    Every line is stripped and split on ``sp``; lines that yield fewer than
    three fields are ignored, all others are kept as tuples of their fields.
    Example line: ``/m/027rn   /m/06cx9   /location/country/form_of_government``

    :param dir: path of the file to read
    :param sp: field separator
    :return: ``(count, triples)`` where ``count`` is ``len(triples)``
    """
    with open(dir) as handle:
        rows = (raw.strip().split(sp) for raw in handle)
        triples = [tuple(fields) for fields in rows if len(fields) >= 3]
    return len(triples), triples

if __name__ == '__main__':
    # Input files: entity names, relation names and the training triples.
    entityPath = r"F:\pycharm的项目\transe\data\FB15k\entity2id.txt"
    relationPath = r"F:\pycharm的项目\transe\data\FB15k\relation2id.txt"
    trainPath = r"F:\pycharm的项目\transe\data\FB15k\train.txt"
    # Output files for the trained vectors.
    relationVectorPath = r"F:\pycharm的项目\transe\data\FB15k\relationVector10.txt"
    entityVectorPath = r"F:\pycharm的项目\transe\data\FB15k\entityVector10.txt"

    entityIdNum, entityList = openDetailsAndId(entityPath)
    relationIdNum, relationList = openDetailsAndId(relationPath)
    tripleNum, tripleList = openTrain(trainPath)

    print("打开TransE")
    # The embedding dimension can be changed here through ``dim``.
    model = TransE(entityList, relationList, tripleList, margin=1, dim=10)
    print("TranE初始化")
    model.initialize()
    model.transE(15000)

    # Store the final vectors once training has finished.
    model.writeRelationVector(relationVectorPath)
    model.writeEntilyVector(entityVectorPath)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/549508
推荐阅读
相关标签
  

闽ICP备14008679号