当前位置:   article > 正文

TransE代码实践(很详细)_transe在代码中的体现

transe在代码中的体现

TransE 是 Bordes 等人于 2013 年发表在 NIPS 上的论文中提出的算法,用于解决多关系数据(multi-relational data)的处理问题。其直观含义是:基于实体和关系的分布式向量表示,将每个三元组实例 (head, relation, tail) 中的关系 relation 看作从实体 head 到实体 tail 的翻译,通过不断调整 h、r 和 t(即 head、relation 和 tail 的向量),使 h + r 尽可能接近 t,即 h + r ≈ t。
这篇文章主要用来记录 TransE 的代码。代码的难点有两处:一是生成随机数的过程相对复杂;二是生成负样本(伪数据)的流程,即 corrupt_head。其余部分按照主函数的执行流程阅读即可。文末附一张 corrupt_head 的执行样例图。

#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <string>
#include <algorithm>
#include <pthread.h>
#include <iostream>
#include <sstream>

using namespace std;

// Pi; used by normal() when evaluating the Gaussian density.
const float pi = 3.141592653589793238462643383;

// Number of worker threads created per batch in train_transe().
int transeThreads = 8;
// Number of epochs run by train_transe().
int transeTrainTimes = 1000;
// Number of batches per epoch; also the divisor that yields transeBatch.
int nbatches = 10;
// Length of every entity and relation vector.
int dimension = 50;
// Fixed per-coordinate step applied by gradient().
float transeAlpha = 0.001;
// Margin in the update condition `sum1 + margin > sum2` of train_kb().
float margin = 1;

// Directory prefix prepended to the input file names read by init().
string inPath = "../../";
// Directory prefix prepended to the output file names written by out_transe().
string outPath = "../../";


// Per-entity index ranges into the sorted triple arrays, filled by init().
// lefHead[e]/rigHead[e]: first/last index in trainHead whose head is e.
// rigHead[e] stays -1 (and lefHead[e] stays 0) when e never occurs as a head.
int *lefHead, *rigHead;
// lefTail[e]/rigTail[e]: first/last index in trainTail whose tail is e,
// with the same -1 / 0 defaults for entities that never occur as a tail.
int *lefTail, *rigTail;

// One training fact: head entity id (h), relation id (r), tail entity id (t).
// Note that init() reads the fields from the file in the order h, t, r.
struct Triple {
    int h, r, t;
};

// Three copies of the training triples, all filled by init():
// trainList keeps file order, trainHead is sorted with cmp_head (h, r, t),
// trainTail is sorted with cmp_tail (t, r, h).
Triple *trainHead, *trainTail, *trainList;

// Ordering functor for Triple: ascending by head, ties broken by relation,
// remaining ties broken by tail.
struct cmp_head {
    bool operator()(const Triple &a, const Triple &b) {
        if (a.h != b.h)
            return a.h < b.h;
        if (a.r != b.r)
            return a.r < b.r;
        return a.t < b.t;
    }
};

// Ordering functor for Triple: ascending by tail, ties broken by relation,
// remaining ties broken by head.
struct cmp_tail {
    bool operator()(const Triple &a, const Triple &b) {
        if (a.t != b.t)
            return a.t < b.t;
        if (a.r != b.r)
            return a.r < b.r;
        return a.h < b.h;
    }
};

/*
 There are some math functions for the program initialization.
 */
// Per-slot generator state; slot `id` is advanced by randd(id).
// The array is allocated in train_transe().
unsigned long long *next_random;

// Advances slot `id` with the linear congruential step
//   state = state * 25214903917 + 11   (wrapping in unsigned long long)
// and returns the new state.
unsigned long long randd(int id) {
    unsigned long long *slot = next_random + id;
    *slot = *slot * 25214903917ULL + 11;
    return *slot;
}

// Draws the next value from generator slot `id` and reduces it modulo x.
// A negative remainder is shifted up by x until it is non-negative.
// x must be non-zero, otherwise the modulo is undefined.
int rand_max(int id, int x) {
    int value = randd(id) % x;
    for (; value < 0; value += x) {
    }
    return value;
}

// Maps one draw of the C library rand() linearly onto the interval
// starting at `min` with width (max - min).
float rand(float min, float max) {
    const double denom = RAND_MAX + 1.0;
    return min + (max - min) * rand() / denom;
}

// Gaussian probability density at x for mean `miu` and standard
// deviation `sigma`.
float normal(float x, float miu, float sigma) {
    float offset = x - miu;
    float exponent = -1 * offset * offset / (2 * sigma * sigma);
    return 1.0 / sqrt(2 * pi) / sigma * exp(exponent);
}

// Rejection sampler: repeatedly draws a candidate uniformly from [min, max]
// and a threshold uniformly from [0, normal(miu, miu, sigma)] (the density
// at the mean), and returns the first candidate whose density
// normal(candidate, miu, sigma) is not below the threshold.
float randn(float miu, float sigma, float min, float max) {
    for (;;) {
        float candidate = rand(min, max);
        float density = normal(candidate, miu, sigma);
        float threshold = rand(0.0, normal(miu, miu, sigma));
        // Written as a negated '>' so the exit condition matches a
        // do { ... } while (threshold > density) loop exactly.
        if (!(threshold > density))
            return candidate;
    }
}
// Rescales the `dimension`-long vector at `con` to Euclidean length 1,
// but only when its current length exceeds 1; shorter vectors are left
// untouched.
void norm(float * con) {
    float squared = 0;
    for (int k = 0; k < dimension; ++k)
        squared += con[k] * con[k];
    const float length = sqrt(squared);
    if (!(length > 1))
        return;
    for (int k = 0; k < dimension; ++k)
        con[k] /= length;
}

/*
 Read triples from the training file.
 */

// Counts read by init(): the leading integers of relation2id.txt and
// entity2id.txt, and the number of triples actually read from triple2id.txt.
int relationTotal, entityTotal, tripleTotal;
// Embedding tables stored row-major: row i occupies
// [i * dimension, (i + 1) * dimension).
float *relationVec, *entityVec;
// Zero-initialised buffers of the same sizes, allocated at the end of init().
// NOTE(review): nothing in this file reads or writes them afterwards.
float *relationVecDao, *entityVecDao;

// 将实体id 关系id 三元组导入、初始化要训练的向量
void init() {
    FILE *fin;
    int tmp;
    
    fin = fopen((inPath + "relation2id.txt").c_str(), "r"); //fopen创建或打开,c_str()返回一个指向正规C字符串的临时的指针常量
    tmp = fscanf(fin, "%d", &relationTotal); //获取总的total数
    fclose(fin);
    //构建关系向量
    relationVec = (float *)calloc(relationTotal * dimension, sizeof(float));//分配realtionTotal*dimension
    for (int i=0;i<relationTotal; i++) {
        for (int ii=0;ii<dimension; ii++)
            relationVec[i*dimension+ii] = randn(0,1.0/dimension,-6/sqrt(dimension),6/sqrt(dimension));//(miu,sigma,min,max)
    }
    fin = fopen((inPath + "entity2id.txt").c_str(), "r");
    tmp = fscanf(fin,"%d",&entityTotal);
    fclose(fin);
    entityVec = (float *)calloc(entityTotal * dimension, sizeof(float));
    for (int i=0;i<entityTotal;i++) {
        for (int ii=0;ii<dimension;ii++)
            entityVec[i * dimension + ii] = randn(0, 1.0 / dimension, -6 / sqrt(dimension), 6 / sqrt(dimension));
        norm(entityVec+i*dimension); // 单个entity向量标准化
    }
    fin = fopen((inPath + "triple2id.txt").c_str(), "r");
    tmp = fscanf(fin, "%d", &tripleTotal);
    trainHead = (Triple *)calloc(tripleTotal, sizeof(Triple));
    trainTail = (Triple *)calloc(tripleTotal, sizeof(Triple));
    trainList = (Triple *)calloc(tripleTotal, sizeof(Triple));
    tripleTotal = 0;
    //trainlist存储三元组,复制给 trainHead和trainTail
    while (fscanf(fin, "%d", &trainList[tripleTotal].h) == 1) {
        tmp = fscanf(fin, "%d", &trainList[tripleTotal].t);
        tmp = fscanf(fin, "%d", &trainList[tripleTotal].r);
        trainHead[tripleTotal].h = trainList[tripleTotal].h;
        trainHead[tripleTotal].t = trainList[tripleTotal].t;
        trainHead[tripleTotal].r = trainList[tripleTotal].r;
        trainTail[tripleTotal].h = trainList[tripleTotal].h;
        trainTail[tripleTotal].t = trainList[tripleTotal].t;
        trainTail[tripleTotal].r = trainList[tripleTotal].r;
        tripleTotal++;
    }
    fclose(fin);
    //按照head和tail排序
    sort(trainHead, trainHead + tripleTotal, cmp_head());
    sort(trainTail, trainTail + tripleTotal, cmp_tail());
    
    lefHead = (int *)calloc(entityTotal, sizeof(int));
    rigHead = (int *)calloc(entityTotal, sizeof(int));
    lefTail = (int *)calloc(entityTotal, sizeof(int));
    rigTail = (int *)calloc(entityTotal, sizeof(int));
    memset(rigHead, -1, sizeof(int)*entityTotal); //初始化
    memset(rigTail, -1, sizeof(int)*entityTotal);
    for (int i=1;i<tripleTotal;i++) {
        if (trainTail[i].t != trainTail[i - 1].t) {
            rigTail[trainTail[i - 1].t] = i - 1;  // 将索引为i-1的t的值置为i-1,意思TrainTail[i-1]的值的终止点
            lefTail[trainTail[i].t] = i; // 将索引为i的t的值置为i,意思是TrainTail[i]与左侧值不同,意思TrainTail[i]的值的终止点
        }
        if (trainHead[i].h != trainHead[i - 1].h) {
            rigHead[trainHead[i - 1].h] = i - 1;
            lefHead[trainHead[i].h] = i;
        }
    }
    rigHead[trainHead[tripleTotal - 1].h] = tripleTotal - 1;
    rigTail[trainTail[tripleTotal - 1].t] = tripleTotal - 1;
    
    relationVecDao = (float*)calloc(dimension * relationTotal, sizeof(float));
    entityVecDao = (float*)calloc(dimension * entityTotal, sizeof(float));
}

/*
 Training process of transE.
 */

// Number of training triples; set from tripleTotal by train_transe().
int transeLen;
// transeLen / nbatches; set by train_transe() and divided by transeThreads
// in transetrainMode() to size each thread's loop.
int transeBatch;
// Loss accumulated over one epoch: reset by train_transe(), increased by
// train_kb(). NOTE(review): train_kb() runs in several worker threads and
// updates this without any synchronisation visible in this file.
float res;

// L1 distance of a triple: sum over all coordinates of
// |entity[e2] - entity[e1] - relation[rel]|.
float calc_sum(int e1, int e2, int rel) {
    const float *firstVec = entityVec + e1 * dimension;
    const float *secondVec = entityVec + e2 * dimension;
    const float *relVec = relationVec + rel * dimension;
    float total = 0;
    for (int k = 0; k < dimension; ++k)
        total += fabs(secondVec[k] - firstVec[k] - relVec[k]);
    return total;
}
// Applies one signed step to a single coordinate of a triple.
// The residual is entity[tailIdx] - entity[headIdx] - relation[relIdx];
// the step is `stepIfPositive` when the residual is > 0 and its negation
// otherwise. The relation and head coordinates are decreased by the step
// and the tail coordinate is increased by it, in that order.
static void shift_triple_coordinate(int headIdx, int tailIdx, int relIdx, float stepIfPositive) {
    float residual = entityVec[tailIdx] - entityVec[headIdx] - relationVec[relIdx];
    float step = (residual > 0) ? stepIfPositive : -stepIfPositive;
    relationVec[relIdx] -= step;
    entityVec[headIdx] -= step;
    entityVec[tailIdx] += step;
}

// One fixed-size update for a pair of triples.
// For triple "a" (e1_a, e2_a, rel_a) every coordinate is moved by
// transeAlpha so that the magnitude of its residual shrinks; for triple "b"
// (e1_b, e2_b, rel_b) every coordinate is moved so that it grows.
// Within each coordinate, "a" is updated before the residual of "b" is
// read, which matters when the two triples share an entity or relation.
void gradient(int e1_a, int e2_a, int rel_a, int e1_b, int e2_b, int rel_b) {
    const int baseA1 = e1_a * dimension;
    const int baseA2 = e2_a * dimension;
    const int baseAr = rel_a * dimension;
    const int baseB1 = e1_b * dimension;
    const int baseB2 = e2_b * dimension;
    const int baseBr = rel_b * dimension;
    for (int k = 0; k < dimension; ++k) {
        shift_triple_coordinate(baseA1 + k, baseA2 + k, baseAr + k, -transeAlpha);
        shift_triple_coordinate(baseB1 + k, baseB2 + k, baseBr + k, transeAlpha);
    }
}

// Scores triple "a" and triple "b" with calc_sum(). When
// score(a) + margin > score(b), adds margin + score(a) - score(b) to the
// running loss `res` and applies gradient() to the pair; otherwise does
// nothing.
void train_kb(int e1_a, int e2_a, int rel_a, int e1_b, int e2_b, int rel_b) {
    const float scoreA = calc_sum(e1_a, e2_a, rel_a);
    const float scoreB = calc_sum(e1_b, e2_b, rel_b);
    if (!(scoreA + margin > scoreB))
        return;
    res += margin + scoreA - scoreB;
    gradient(e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
}
// Picks a replacement tail entity for the pair (h, r), using generator
// slot `id`. Despite the name, the head is what stays fixed: the caller
// uses the returned id as the tail of the corrupted triple.
//
// Steps:
//  1. Within the run of head h in trainHead (lefHead[h]..rigHead[h], sorted
//     by relation then tail) two binary searches locate [ll, rr]: ll is the
//     first index whose relation is >= r, rr the last index whose relation
//     is <= r. When (h, r) occurs in the training data this is exactly the
//     run of triples (h, r, *).
//  2. A random number tmp in [0, entityTotal - (rr - ll + 1)) is drawn.
//  3. tmp is mapped to an entity id by skipping over the tails stored in
//     [ll, rr], so the result is none of those tails (assuming the stored
//     tails are distinct).
int corrupt_head(int id, int h, int r) {
    int lef, rig, mid, ll, rr;
    lef = lefHead[h] - 1;
    rig = rigHead[h];
    while (lef + 1 < rig) { // more than one candidate index remains
        mid = (lef + rig) >> 1; // midpoint (divide by 2)
        if (trainHead[mid].r >= r) rig = mid; else
            lef = mid;
    }
    ll = rig; // first index in the run whose relation is >= r
    lef = lefHead[h];
    rig = rigHead[h] + 1;
    while (lef + 1 < rig) {
        mid = (lef + rig) >> 1;
        if (trainHead[mid].r <= r) lef = mid; else
            rig = mid;
    }
    rr = lef; // last index in the run whose relation is <= r
    int tmp = rand_max(id, entityTotal - (rr - ll + 1)); // random value below entityTotal - (rr - ll + 1)
    if (tmp < trainHead[ll].t) return tmp; // below the smallest stored tail: nothing to skip
    if (tmp > trainHead[rr].t - rr + ll - 1) return tmp + rr - ll + 1; // past the largest stored tail: skip all rr - ll + 1 of them
    // Otherwise find the last index whose stored tail still lies at or
    // below the shifted value and skip the (lef - ll + 1) tails up to it.
    lef = ll, rig = rr + 1;
    while (lef + 1 < rig) {
        mid = (lef + rig) >> 1;
        if (trainHead[mid].t - mid + ll - 1 < tmp)
            lef = mid;
        else
            rig = mid;
    }
    return tmp + lef - ll + 1;
}

// Picks a replacement head entity for the pair (t, r), using generator
// slot `id`. Despite the name, the tail is what stays fixed: the caller
// uses the returned id as the head of the corrupted triple.
//
// Mirror image of corrupt_head(): it works on trainTail (sorted by tail,
// relation, head) with the lefTail/rigTail ranges, locates the index range
// [ll, rr] whose relation matches r inside the run of tail t, draws a
// random number below entityTotal - (rr - ll + 1) and maps it to an entity
// id by skipping over the heads stored in [ll, rr].
int corrupt_tail(int id, int t, int r) {
    int lef, rig, mid, ll, rr;
    lef = lefTail[t] - 1;
    rig = rigTail[t];
    while (lef + 1 < rig) {
        mid = (lef + rig) >> 1;
        if (trainTail[mid].r >= r) rig = mid; else
            lef = mid;
    }
    ll = rig; // first index in the run whose relation is >= r
    lef = lefTail[t];
    rig = rigTail[t] + 1;
    while (lef + 1 < rig) {
        mid = (lef + rig) >> 1;
        if (trainTail[mid].r <= r) lef = mid; else
            rig = mid;
    }
    rr = lef; // last index in the run whose relation is <= r
    int tmp = rand_max(id, entityTotal - (rr - ll + 1));
    if (tmp < trainTail[ll].h) return tmp; // below the smallest stored head: nothing to skip
    if (tmp > trainTail[rr].h - rr + ll - 1) return tmp + rr - ll + 1; // past the largest stored head: skip all of them
    // Otherwise binary-search how many stored heads have to be skipped.
    lef = ll, rig = rr + 1;
    while (lef + 1 < rig) {
        mid = (lef + rig) >> 1;
        if (trainTail[mid].h - mid + ll - 1 < tmp)
            lef = mid;
        else
            rig = mid;
    }
    return tmp + lef - ll + 1;
}
// 接受线程id作为输入,调用corrupt生成正负样本,train_kb进行训练
void* transetrainMode(void *con) {
    int id;
    id = (unsigned long long)(con); //补0即可
    next_random[id] = rand();
    for (int k = transeBatch / transeThreads; k >= 0; k--) { // 一个batch训练的样本数按照线程均分
        int j;
        // 生成一个样本随机的样本id
        int i = rand_max(id, transeLen); // i为生成的随机数
        int pr = 500; //一半的概率1/2决定生成 伪head tail
        if (randd(id) % 1000 < pr) {
            // 选择正、负样本作为训练输入
            j = corrupt_head(id, trainList[i].h, trainList[i].r);
            train_kb(trainList[i].h, trainList[i].t, trainList[i].r, trainList[i].h, j, trainList[i].r);
        } else {
            j = corrupt_tail(id, trainList[i].t, trainList[i].r);
            train_kb(trainList[i].h, trainList[i].t, trainList[i].r, j, trainList[i].t, trainList[i].r);
        }
        norm(relationVec + dimension * trainList[i].r); // 标准化
        norm(entityVec + dimension * trainList[i].h);
        norm(entityVec + dimension * trainList[i].t);
        norm(entityVec + dimension * j);
    }
    pthread_exit(NULL);
}

// Runs the whole training schedule: transeTrainTimes epochs, each made of
// nbatches batches; every batch starts transeThreads worker threads running
// transetrainMode() and waits for all of them. After each completed epoch
// the accumulated loss `res` is printed.
//
// `con` is unused. The function returns early, after printing to stderr,
// when there are no training triples, when the working arrays cannot be
// allocated, or when a thread cannot be created; in the last case only the
// threads that were actually started are joined.
void train_transe(void *con) {
    (void)con;
    transeLen = tripleTotal;
    if (transeLen <= 0) {
        // The workers reduce a random number modulo transeLen.
        fprintf(stderr, "train_transe: no training triples, nothing to train\n");
        return;
    }
    transeBatch = transeLen / nbatches; // triples per batch

    // One generator slot per thread; release the array of an earlier call.
    free(next_random);
    next_random = (unsigned long long *)calloc(transeThreads, sizeof(unsigned long long));
    // Thread handles are reused by every batch instead of reallocated.
    pthread_t *pt = (pthread_t *)malloc(transeThreads * sizeof(pthread_t));
    if (transeThreads > 0 && (next_random == NULL || pt == NULL)) {
        fprintf(stderr, "train_transe: out of memory\n");
        free(pt);
        return;
    }

    bool failed = false;
    for (int epoch = 0; epoch < transeTrainTimes && !failed; epoch++) {
        res = 0;
        for (int batch = 0; batch < nbatches && !failed; batch++) {
            long started = 0;
            for (; started < transeThreads; started++) {
                // The thread index travels to the worker inside the pointer.
                int rc = pthread_create(&pt[started], NULL, transetrainMode, (void*)started);
                if (rc != 0) {
                    fprintf(stderr, "train_transe: pthread_create failed: %s\n", strerror(rc));
                    failed = true;
                    break;
                }
            }
            // Join only handles that pthread_create actually initialised.
            for (long a = 0; a < started; a++)
                pthread_join(pt[a], NULL);
        }
        if (!failed)
            printf("epoch %d %f\n", epoch, res);
    }
    free(pt);
}

/*
 save result
 */

void out_transe() {
    stringstream ss;
    ss << dimension;
    string dim = ss.str();
    
    FILE* f2 = fopen((outPath + "TransE_relation2vec_" + dim + ".vec").c_str(), "w");
    FILE* f3 = fopen((outPath + "TransE_entity2vec_" + dim + ".vec").c_str(), "w");
    for (int i=0; i < relationTotal; i++) {
        int last = dimension * i;
        for (int ii = 0; ii < dimension; ii++)
            fprintf(f2, "%.6f\t", relationVec[last + ii]);
        fprintf(f2,"\n");
    }
    for (int  i = 0; i < entityTotal; i++) {
        int last = i * dimension;
        for (int ii = 0; ii < dimension; ii++)
            fprintf(f3, "%.6f\t", entityVec[last + ii] );
        fprintf(f3,"\n");
    }
    fclose(f2);
    fclose(f3);
}
/*
 Main function
 */
// Program entry: load the data, train, save the vectors, then print the
// elapsed wall-clock time in seconds.
int main() {
    const time_t begin = time(NULL);
    init();
    train_transe(NULL);
    out_transe();
    const time_t finish = time(NULL);
    cout << finish - begin << " s" << endl;
    return 0;
}
(原文配图:corrupt_head 执行样例)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/549472
推荐阅读
相关标签
  

闽ICP备14008679号