当前位置:   article > 正文

OpenKE-TransE代码阅读_openke代码

openke代码

OpenKE-TransE代码阅读

参考
TransE模型学习笔记
小黑笔记:transe模型
权值初始化
PyTorch权值初始化的十种方法

TransE伪代码
在这里插入图片描述

输入:训练集S={(h,l,t)},实体集L,关系集L,margin值y,嵌入向量维度k
在这里插入图片描述

γ \gamma γ:边距超参数。作用是,d[正三元组]-d[负三元组],会得到一个负数,margin是一个正数,使得整体式子是一个正数。随着loss的减小,d[正三元组]-d[负三元组]负数会越来越小,当其绝对值超过margin时,整个式子会变成负数。但是loss只取正数,当得到负数时,整个式子置为0,所以正负三元组最大距离为margin。

margin代表正负样本之间的最大距离,有了margin不会让负样本的d变得无限大

输入:relation2id.txt , entity2id.txt , train2id.txt
relation2id.txt:

/location/country/form_of_government 0
/tv/tv_program/regular_cast./tv/regular_tv_appearance/actor 1

entity2id.txt:

/m/027rn 0
/m/06cx9 1

train2id.txt:

0 1 0
2 3 1

代码逻辑
在这里插入图片描述

TrainDataLoader:数据采样,调用C++函数库方法

# 迭代器
def __iter__(self):
    if self.sampling_mode == "normal":
        return TrainDataSampler(self.nbatches, self.sampling)
    else:
        return TrainDataSampler(self.nbatches, self.cross_sampling)
# 调用sampling方法
def sampling(self):
    # 调用c++采样方法,传入的是地址
    self.lib.sampling(
        self.batch_h_addr,
        self.batch_t_addr,
        self.batch_r_addr,
        self.batch_y_addr,
        self.batch_size,
        self.negative_ent,
        self.negative_rel,
        0,
        self.filter,
        0,
        0
    )
    # 返回数据
    return {
        "batch_h": self.batch_h, 
        "batch_t": self.batch_t, 
        "batch_r": self.batch_r, 
        "batch_y": self.batch_y,
        "mode": "normal"
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

sampling中调用c++库中的Base.cpp的方法

// 启动线程,调用了getBatch方法
extern "C"
void sampling(
    INT *batch_h, // 地址
    INT *batch_t, 
    INT *batch_r, 
    REAL *batch_y, 
    INT batchSize, 
    INT negRate = 1, 
    INT negRelRate = 0, 
    INT mode = 0,
    bool filter_flag = true,
    bool p = false, 
    bool val_loss = false
) {
    pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t));
    Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter));
    for (INT threads = 0; threads < workThreads; threads++) {
        para[threads].id = threads;
        //...设置参数
        ....
    }
    for (INT threads = 0; threads < workThreads; threads++)
        pthread_join(pt[threads], NULL);// 启动线程,调用了getBatch方法
    free(pt);
    free(para);
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

C++ getBatch方法,打乱三元组获得负样本

// sampling调用getBatch方法,打乱三元组
void* getBatch(void* con) {
    Parameter *para = (Parameter *)(con);
    INT id = para -> id;
    // 获取参数,省略
    bool p = para -> p;
    bool val_loss = para -> val_loss;
    INT mode = para -> mode;
    bool filter_flag = para -> filter_flag;
    INT lef, rig;
    if (batchSize % workThreads == 0) {
        lef = id * (batchSize / workThreads);
        rig = (id + 1) * (batchSize / workThreads);
    } else {
        lef = id * (batchSize / workThreads + 1);
        rig = (id + 1) * (batchSize / workThreads + 1);
        if (rig > batchSize) rig = batchSize;
    }
    REAL prob = 500;
    if (val_loss == false) {
        for (INT batch = lef; batch < rig; batch++) {
            INT i = rand_max(id, trainTotal);
            batch_h[batch] = trainList[i].h;
            batch_t[batch] = trainList[i].t;
            batch_r[batch] = trainList[i].r;
            batch_y[batch] = 1;
            INT last = batchSize;
            // 替换entity生成负样本
            for (INT times = 0; times < negRate; times ++) {
                if (mode == 0){ // mode==0
                    // 设置一个参数判断是随机替换头节点还是尾节点
                    if (bernFlag)
                        prob = 1000 * right_mean[trainList[i].r] / (right_mean[trainList[i].r] + left_mean[trainList[i].r]);
                    if (randd(id) % 1000 < prob) {
                        batch_h[batch + last] = trainList[i].h;
                        batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r);
                        batch_r[batch + last] = trainList[i].r;
                    } else {
                        batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r);
                        batch_t[batch + last] = trainList[i].t;
                        batch_r[batch + last] = trainList[i].r;
                    }
                    batch_y[batch + last] = -1;
                    last += batchSize;
                } else {
                    // ...省略
                }
            }
            // 替换relation生成负样本
            for (INT times = 0; times < negRelRate; times++) {
                batch_h[batch + last] = trainList[i].h;
                batch_t[batch + last] = trainList[i].t;
                batch_r[batch + last] = corrupt_rel(id, trainList[i].h, trainList[i].t, trainList[i].r, p);
                batch_y[batch + last] = -1;
                last += batchSize;
            }
        }
    }
    else{
        //...省略
    }
    pthread_exit(NULL);
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

L2范数归一化
范数简单可以理解为用来表征向量空间中的距离,而距离的定义很抽象,只要满足非负、自反、三角不等式就可以称之为距离。采用范式作为正则项,可以防止模型训练过拟合。
在这里插入图片描述

TransE.py

def _calc(self, h, t, r, mode):
    if self.norm_flag:
        h = F.normalize(h, 2, -1) // l2范数归一化
        r = F.normalize(r, 2, -1)
        t = F.normalize(t, 2, -1)
    if mode != 'normal':
        h = h.view(-1, r.shape[0], h.shape[-1])
        t = t.view(-1, r.shape[0], t.shape[-1])
        r = r.view(-1, r.shape[0], r.shape[-1])
    if mode == 'head_batch':
        score = h + (r - t)
    else:
        score = (h + r) - t
    score = torch.norm(score, self.p_norm, -1).flatten()
    return score
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

权重初始化
网络模型的权值初始化是很重要的步骤,不恰当的权值可能会引发梯度消失或梯度爆炸。
在全连接层中,第i层的权重梯度会依赖于第i-1层的输出,可能会造成梯度消失或爆炸。因此要通过设置权值控制网络层输出的数值范围。

TransE.py

if margin == None or epsilon == None:
    # 初始化函数
    # Xavier,Kaiming
    # 均匀分布 torch.nn.init.xavier_uniform_(tensor, gain=1)
    nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
    nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
else:
    self.embedding_range = nn.Parameter(
        torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False
    )
    nn.init.uniform_(
        tensor = self.ent_embeddings.weight.data, 
        a = -self.embedding_range.item(), 
        b = self.embedding_range.item()
    )
    nn.init.uniform_(
        tensor = self.rel_embeddings.weight.data, 
        a= -self.embedding_range.item(), 
        b= self.embedding_range.item()
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

涉及两种权值初始化方式:
1. Xavier方式
基本思想是通过网络层时,输入和输出的方差相同,包括前向传播和后向传播,具体推导可以看这一篇介绍权值初始化。Xavier针对有非线性激活函数时的权值初始化,目标是保持数据的方差维持在1左右,主要针对饱和激活函数如 sigmoid 和 tanh 等。
公式推导是从“方差一致性”出发,初始化的分布有均匀分布和正态分布两种。
Xavier均匀分布
基本思想是通过网络层时,输入和输出的方差相同,包括前向传播和后向传播,具体可以看

torch.nn.init.xavier_uniform_(tensor, gain=1)
>>>  nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
  • 1
  • 2
  • 服从均匀分布 U ( − a , a ) U(-a,a) U(a,a),参数 a = g a i n ∗ 6 ( f a n i n + f a n o u t ) a=gain*\sqrt{\frac{6}{(fan_in+fan_out)}} a=gain(fanin+fanout)6
  • 参数 g a i n gain gain表示增益,大小由激活函数类型定义
tanh_gain = nn.init.calculate_gain('tanh')
  • 1

Xavier标准正态分布

torch.nn.init.xavier_normal_(tensor, gain=1)
  • 1
  • 服从正态分布, m e a n = 0 , s t d = g a i n ∗ 2 ( f a n i n + f a n o u t ) mean=0,std=gain*\sqrt{\frac{2}{(fan_{in}+fan_{out})}} mean=0,std=gain(fanin+fanout)2

2. torch.nn方式

torch.nn.init.uniform_(tensor, a=0, b=1)
使值服从均匀分布U(a,b)
>>>  nn.init.uniform_(
        tensor = self.ent_embeddings.weight.data, 
        a = -self.embedding_range.item(), 
        b = self.embedding_range.item()
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/549544
推荐阅读
相关标签
  

闽ICP备14008679号