当前位置:   article > 正文

TransE算法详解【代码学习系列】【知识图谱】【表示学习】_transe代码

transe代码

1 代码来源

本代码来源于github项目地址,项目实现了TransE算法。下面结合项目代码,对TransE算法原理及实现进行详细说明。

2基本思想

TransE是一篇Bordes等人2013年发表在NIPS上的文章提出的算法。它的提出,是为了解决多关系数据(multi-relational data)的处理问题。我们现在有很多很多的知识库数据knowledge bases (KBs),比如Freebase、 Google Knowledge Graph 、 GeneOntology等等。
TransE的直观含义,就是TransE基于实体和关系的分布式向量表示,将每个三元组实例(head,relation,tail)中的关系relation看做从实体head到实体tail的翻译,通过不断调整h、r和t(head、relation和tail的向量),使(h + r) 尽可能与 t 相等,即 h + r = t。

TransE原理示意图

3处理流程

TransE算法可划分为表示向量初始化(步骤1—步骤3)、批训练数据集构建(步骤6—步骤11)和表示向量更新(步骤12、步骤5)三部分。

TransE算法
表示向量初始化(步骤1—步骤3)采用k维随机均匀分布对每个实体和每个关系进行初始化。
批训练数据集构建(步骤6—步骤11)则从训练集合中随机选出正面样本(h,r,t),然后基于正面样本,保持正面样本中h,r或r,t不变,改变t或h,获得负面样本,合并一起构成批批训练数据集。
表示向量更新(步骤12、步骤5)则采用随机梯度下降法,对批训练数据集中正面样本(h,r,t)和负样本的向量表示进行更新。

TranE直接对向量表示进行训练,每个实体(head或tail)对应一个k维向量,每个关系对应一个k维向量,所有的k维向量即为TransE模型的参数,需要训练的参数。

4代码实现

项目采用FB15K进行训练演示,FB15K包括三种类型数据,分别为实体数据,关系数据,和(h,t,r)数据。
实体数据:entity2id,包括两列,第一例为实体名字,第二列为实体ID,中间TAB键隔开。

/m/06rf7 0
/m/0c94fn 1
/m/016ywr 2
/m/01yjl 3
  • 1
  • 2
  • 3
  • 4

关系数据:relation2id,包括两列,第一例为关联名字,第二列为关系ID,中间TAB键隔开。

/people/appointed_role/appointment./people/appointment/appointed_by 0
/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency 1
/tv/tv_series_episode/guest_stars./tv/tv_guest_role/actor 2
/music/performance_role/track_performances./music/track_contribution/contributor 3
  • 1
  • 2
  • 3
  • 4

(h,t,r)数据,包括三列,第一列H实体名字,第二列T实体名字,第三列R关系名字,它们之间TAB键隔开。

/m/07pd_j /m/02l7c8 /film/film/genre
/m/06wxw /m/02fqwt /location/location/time_zones
/m/0d4fqn /m/03wh8kl /award/award_winner/awards_won./award/award_honor/award_winner
/m/07kcvl /m/0bgv8y /american_football/football_team/historical_roster./american_football/football_historical_roster_position/position_s
/m/012201 /m/0ckrnn /film/music_contributor/film
  • 1
  • 2
  • 3
  • 4
  • 5

(1)表示向量初始化

// 初始化程序模块

def init(dim):
   return uniform(-6/(dim**0.5), 6/(dim**0.5))

entityVectorList = {}
relationVectorList = {}
   for entity in self.entityList:  #先对实体表示向量初始化,关系表示向量初始化方法相同
   	n = 0	
   	entityVector = []
   	while n < self.dim:                  #产生dim个均匀分布的随机数
   		ram = init(self.dim)        #产生随机数
   		entityVector.append(ram)  
   		n += 1
   	entityVector = norm(entityVector)    #归一化为dim维向量
   	entityVectorList[entity] = entityVector   #构建字典,字典名为实体名,字典项为向量表示
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

(2)批训练数据集构建

// 首先从训练集中获取正样本,然后对正样本进行修改,获得负样本,正样本和负样本合在一起,作为一条样本
   def getSample(self, size):
       return sample(self.tripleList, size)
   def getCorruptedTriplet(self, triplet):
       '''
       training triplets with either the head or tail replaced by a random entity (but not both at the same time)
       :param triplet:
       :return corruptedTriplet:
       '''
       i = uniform(-1, 1)    #随机选择,小于0改变正样本中的H实体,变为负样本,大于0改变T实体,变为负样本
       if i < 0:#小于0,打坏三元组的第一项
           while True:
               entityTemp = sample(self.entityList.keys(), 1)[0]
               if entityTemp != triplet[0]:
                   break
           corruptedTriplet = (entityTemp, triplet[1], triplet[2])
       else:#大于等于0,打坏三元组的第二项
           while True:
               entityTemp = sample(self.entityList.keys(), 1)[0]
               if entityTemp != triplet[1]:
                   break
           corruptedTriplet = (triplet[0], entityTemp, triplet[2])
       return corruptedTriplet

Sbatch = self.getSample(150)
Tbatch = []#元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))}
for sbatch in Sbatch:
   tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch))
   if(tripletWithCorruptedTriplet not in Tbatch):
   	Tbatch.append(tripletWithCorruptedTriplet)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

(3)表示向量更新

下面对上述代码进行详解。
首先提取用于单次训练的正样本和负样本向量表示。感觉没必须进行深度拷贝,直接计算就行,有没有大神指点一下。
数据提取

//正样本向量表示和负样本向量表示提取,正样本和负样本中的关系对象是相同的
headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
tailEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][1]]
relationVector = copyRelationList[tripletWithCorruptedTriplet[0][2]]
headEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][0]]
tailEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][1]]
#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple            
headEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][0]]
tailEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][1]]
relationVectorBeforeBatch = self.relationList[tripletWithCorruptedTriplet[0][2]]
headEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][0]]
tailEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][1]]       
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

损失计算

//计算正负样本的L1或L2距离范数,然后计算合页损失函数(hinge loss function)
if self.L1:
	distTriplet = distanceL1(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)
	distCorruptedTriplet = distanceL1(headEntityVectorWithCorruptedTripletBeforeBatch, 						tailEntityVectorWithCorruptedTripletBeforeBatch ,  relationVectorBeforeBatch)
else:
	distTriplet = distanceL2(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)
	distCorruptedTriplet = distanceL2(headEntityVectorWithCorruptedTripletBeforeBatch, 		tailEntityVectorWithCorruptedTripletBeforeBatch ,  relationVectorBeforeBatch)
eg = self.margin + distTriplet - distCorruptedTriplet

	
def distanceL1(h, t ,r):  #L1距离范数,向量每个原素绝对值的和
    s = h + r - t
    sum = fabs(s).sum()    
    return sum    
    
def distanceL2(h, t, r): #L2距离范数,向量每个元素平方的和
    s = h + r - t
    sum = (s*s).sum()
    return sum        
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

梯度计算

//计算合页损失函数和梯度,[function]+ 是一个取正值的函数
if eg > 0:  #如果大于0,进行处理,小于等于0不处理
	self.loss += eg     
	if self.L1: #L1范数为每个向量元素的绝对值,如果向量元素大于0,其梯度为1,小于0,其梯度为-1。程序此处直接取梯度的负数,表示向量更新的时候直接加
		tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)
		tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - 				headEntityVectorWithCorruptedTripletBeforeBatch -relationVectorBeforeBatch)
		tempPositiveL1 = []
     	tempNegtativeL1 = []
		for i in range(self.dim):
			if tempPositive[i] >= 0:  #根据距离正负,得到向量表示各元素的梯度
				tempPositiveL1.append(1)
			else:
	   			tempPositiveL1.append(-1)
         		if tempNegtative[i] >= 0:
          		tempNegtativeL1.append(1)
         		else:
          		tempNegtativeL1.append(-1)
              		tempPositive = array(tempPositiveL1)  
             	 	tempNegtative = array(tempNegtativeL1) 
	else: #L2范数为每个向量元平方的和,提取直接为距离向量
		tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)
		tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - 			headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

梯度下降

// 对正负样本相关的(h,t,r)表示向量进行梯度更新,由于梯度求取的时候直接取负,因此梯度更新的时候直接加(即加变减,减变加)
headEntityVector = headEntityVector + tempPositive
tailEntityVector = tailEntityVector - tempPositive
relationVector = relationVector + tempPositive - tempNegtative
headEntityVectorWithCorruptedTriplet = headEntityVectorWithCorruptedTriplet - tempNegtative
tailEntityVectorWithCorruptedTriplet = tailEntityVectorWithCorruptedTriplet + tempNegtative
# 为了方便训练并避免过拟合,加上约束条件,进行归一化处理
copyEntityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector)
copyEntityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector)
copyRelationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector)
copyEntityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet)
copyEntityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet)
# 更新完后,写入原变量
self.entityList = copyEntityList
self.relationList = copyRelationList      
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

5总结

TransE算法直接对每个实体、每个关系进行参数表示,抛开语义表示,根据实体关系间的三角向量关系约束,进行梯度优化,感觉存在参数过度,语义无关等问题。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/549504
推荐阅读
相关标签
  

闽ICP备14008679号