当前位置:   article > 正文

TransE算法代码学习_transe-master

transe-master

参考学习:https://blog.csdn.net/weixin_44023339/article/details/100080669?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

python代码:

from random import uniform, sample
from numpy import *
from copy import deepcopy

class TransE:
    def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):
        self.margin = margin#避免梯度值为0?
        self.learingRate = learingRate#学习率
        self.dim = dim#向量维度
        self.entityList = entityList#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量(使用narray)。
        self.relationList = relationList#理由同上
        self.tripleList = tripleList#理由同上
        self.loss = 0
        self.L1 = L1

    def initialize(self):
        """Create an initial embedding vector for every entity and relation.

        For each entity, and then for each relation, ``init(self.dim)`` is
        called ``self.dim`` times to collect the vector components, and the
        resulting list is passed through ``norm`` (the original comment
        describes that step as normalization; ``init`` and ``norm`` are
        defined outside this method). A progress line is printed after the
        entities and again after the relations.

        Side effects:
            ``self.entityList`` and ``self.relationList`` are rebound to
            dicts mapping each entity / relation to its vector.
        """
        # Entities first: one vector of self.dim components per entity.
        entity_vectors = {}
        for entity in self.entityList:
            components = [init(self.dim) for _ in range(self.dim)]
            entity_vectors[entity] = norm(components)
        print("entityVector初始化完成,数量是%d"%len(entity_vectors))

        # Relations are built in exactly the same way.
        relation_vectors = {}
        for relation in self.relationList:
            components = [init(self.dim) for _ in range(self.dim)]
            relation_vectors[relation] = norm(components)
        print("relationVectorList初始化完成,数量是%d"%len(relation_vectors))

        # From here on the attributes are dicts: {name: vector, ...}.
        self.entityList = entity_vectors
        self.relationList = relation_vectors

    def transE(self, cI = 20):
        print("训练开始")
        for cycleIndex in range(cI):
            Sbatch = self.getSample(150)#[(),(),....()]
            print(Sbatch)
            Tbatch = []#150个元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))}
            for sbatch in Sbatch:
                tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch))
                if(tripletWithCorruptedTriplet not in Tbatch):
                    Tbatch.append(tripletWithCorruptedTriplet)
            self.update(Tbatch)
            if cycleIndex % 100 == 0:
                print("第%d次循环"%cycleIndex)
                print(self.loss)
                #self.writeRelationVector("D:\\transE-master\\transE-master\\relatio
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/549496
推荐阅读
相关标签
  

闽ICP备14008679号