赞
踩
清华大学建立了一个开放的知识表示框架OpenKE:OpenKE - An Open Source Framework for knowledge graph。
该框架集成了TransE、TransH、TransR、TransD、RESCAL、DistMult、HolE、ComplEx等知识表示学习算法,其GitHub地址为:OpenKE - GitHub,仓库中还包含了训练所用的数据集及测试集。
本文主要针对其中的TransE模型,按照自己的理解增加了一些注释,不改变原有代码本身的实现。下文以train_transe_FB15K237.py为例。
train_transe_FB15K237.py中训练数据加载的Python代码
# Build the training data loader for FB15K237.
# Each entry below is forwarded as a keyword argument to TrainDataLoader.__init__;
# arguments not listed (tri_file, ent_file, rel_file, batch_size) keep their defaults.
train_loader_options = dict(
    in_path="./benchmarks/FB15K237/",  # directory containing train2id.txt, entity2id.txt, relation2id.txt
    nbatches=100,            # batch_size is derived as (number of training triples) // nbatches
    threads=8,               # number of native sampling threads
    sampling_mode="normal",
    bern_flag=1,             # choose head vs. tail corruption with relation-dependent probability
    filter_flag=1,
    neg_ent=25,              # entity-corrupted negatives per positive triple
    neg_rel=0,               # relation-corrupted negatives per positive triple
)
train_dataloader = TrainDataLoader(**train_loader_options)
其中的8个参数会传到TrainDataLoader.py中class TrainDataLoader(object)的初始化函数__init__中
class TrainDataLoader(object):
    """Training-data loader backed by the native OpenKE library ``Base.so``."""

    def __init__(self,
                 in_path="./",            # dataset root directory
                 tri_file=None,           # training-triple file
                 ent_file=None,           # entity file
                 rel_file=None,           # relation file
                 batch_size=None,         # triples per batch
                 nbatches=None,           # number of batches
                 threads=8,               # number of sampling threads
                 sampling_mode="normal",  # sampling method
                 bern_flag=False,
                 filter_flag=True,
                 neg_ent=1,
                 neg_rel=0):
        """Load ``../release/Base.so`` (relative to this file), store the
        configuration and import the training data by calling ``self.read()``.

        When ``in_path`` is not None it takes precedence: ``tri_file``,
        ``ent_file`` and ``rel_file`` are replaced by ``train2id.txt``,
        ``entity2id.txt`` and ``relation2id.txt`` under ``in_path``.
        """
        base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.so"))
        self.lib = ctypes.cdll.LoadLibrary(base_file)
        # Declare the ctypes argument types of the C function
        #   sampling(batch_h, batch_t, batch_r, batch_y,
        #            batchSize, negRate, negRelRate, mode, filter_flag, p, val_loss)
        # The first four are raw buffer addresses; the rest are passed as 64-bit ints.
        # NOTE(review): the last three parameters are `bool` on the C side but are
        # declared c_int64 here; this relies on the platform calling convention — confirm.
        self.lib.sampling.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
        ]
        self.in_path = in_path
        self.tri_file = tri_file
        self.ent_file = ent_file
        self.rel_file = rel_file
        if in_path is not None:
            # Derive the three file paths from the dataset directory.
            self.tri_file = in_path + "train2id.txt"
            self.ent_file = in_path + "entity2id.txt"
            self.rel_file = in_path + "relation2id.txt"
        # Essential parameters.
        self.work_threads = threads
        self.nbatches = nbatches
        self.batch_size = batch_size
        self.bern = bern_flag
        self.filter = filter_flag
        self.negative_ent = neg_ent  # entity-corrupted negatives per positive triple
        self.negative_rel = neg_rel  # relation-corrupted negatives per positive triple
        self.sampling_mode = sampling_mode
        self.cross_sampling_flag = 0
        self.read()
初始化函数的最后,会调用TrainDataLoader.py中的read()函数,读取训练数据
def read(self):
    """Import the training set into the native library and allocate batch buffers.

    Reads ``in_path`` (or ``tri_file``/``ent_file``/``rel_file`` when ``in_path``
    is None), ``bern``, ``work_threads``, ``batch_size``, ``nbatches``,
    ``negative_ent`` and ``negative_rel`` from ``self``. Sets ``relTotal``,
    ``entTotal`` and ``tripleTotal``, fills in whichever of ``batch_size`` /
    ``nbatches`` is None, and allocates the ``batch_*`` arrays plus their raw
    addresses (``batch_*_addr``) for the C sampler to write into.
    """

    def _c_string(text):
        # create_string_buffer(bytes) sizes the buffer from the *encoded* bytes and
        # appends a NUL terminator. The former len(text) * 2 size counted characters,
        # so paths with multi-byte UTF-8 characters (3 bytes per CJK character) made
        # ctypes raise ValueError, and an empty path got no terminator at all.
        return ctypes.create_string_buffer(text.encode())

    if self.in_path is not None:
        self.lib.setInPath(_c_string(self.in_path))
    else:
        self.lib.setTrainPath(_c_string(self.tri_file))
        self.lib.setEntPath(_c_string(self.ent_file))
        self.lib.setRelPath(_c_string(self.rel_file))
    self.lib.setBern(self.bern)
    self.lib.setWorkThreads(self.work_threads)  # number of sampling threads
    self.lib.randReset()                        # reseed every thread's random state
    self.lib.importTrainFiles()                 # load and de-duplicate the training triples
    self.relTotal = self.lib.getRelationTotal()
    self.entTotal = self.lib.getEntityTotal()
    self.tripleTotal = self.lib.getTrainTotal()

    # Derive whichever of batch_size / nbatches was not given from the triple count.
    if self.batch_size is None:
        self.batch_size = self.tripleTotal // self.nbatches
    if self.nbatches is None:
        self.nbatches = self.tripleTotal // self.batch_size
    # Each positive triple occupies one slot, followed by neg_ent + neg_rel negative slots.
    self.batch_seq_size = self.batch_size * (1 + self.negative_ent + self.negative_rel)

    # Zero-filled buffers the C sampler fills through the raw addresses below:
    # heads, tails, relations and labels (1 = original triple, -1 = corrupted triple).
    self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64)
    self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64)
    self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64)
    self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32)
    self.batch_h_addr = self.batch_h.__array_interface__["data"][0]
    self.batch_t_addr = self.batch_t.__array_interface__["data"][0]
    self.batch_r_addr = self.batch_r.__array_interface__["data"][0]
    self.batch_y_addr = self.batch_y.__array_interface__["data"][0]
在read()函数中会调用很多C++函数,比如setInPath、setTrainPath、setEntPath、setRelPath、setBern、setWorkThreads、randReset、importTrainFiles、getRelationTotal、getEntityTotal、getTrainTotal等。由于这些C++函数太多,本文直接按源文件贴出带注释的代码,而不再逐个函数拆开注释。
#include "Setting.h"
#include "Random.h"
#include "Reader.h"
#include "Corrupt.h"
#include "Test.h"
#include <cstdlib>
#include <pthread.h>

// C-linkage declarations of the entry points defined in the headers above,
// so they are exported from the shared library under their plain names.
extern "C" void setInPath(char *path);
extern "C" void setTrainPath(char *path);
extern "C" void setValidPath(char *path);
extern "C" void setTestPath(char *path);
extern "C" void setEntPath(char *path);
extern "C" void setRelPath(char *path);
extern "C" void setOutPath(char *path);
extern "C" void setWorkThreads(INT threads);
extern "C" void setBern(INT con);
extern "C" INT getWorkThreads();
extern "C" INT getEntityTotal();
extern "C" INT getRelationTotal();
extern "C" INT getTripleTotal();
extern "C" INT getTrainTotal();
extern "C" INT getTestTotal();
extern "C" INT getValidTotal();
extern "C" void randReset();
extern "C" void importTrainFiles();

// Arguments for one getBatch worker; sampling() fills one entry per thread.
struct Parameter {
    INT id;            // worker index in [0, workThreads); also selects the per-thread RNG
    INT *batch_h;      // output buffers shared by all workers (each writes its own slice)
    INT *batch_t;
    INT *batch_r;
    REAL *batch_y;
    INT batchSize;     // number of positive triples in the batch
    INT negRate;       // entity-corrupted negatives per positive
    INT negRelRate;    // relation-corrupted negatives per positive
    bool p;            // forwarded to corrupt_rel (probability-weighted relation sampling)
    bool val_loss;     // true: copy validation triples instead of sampling training triples
    INT mode;          // 0: head or tail corruption; -1: always replace the head; otherwise replace the tail
    bool filter_flag;  // forwarded to corrupt_head / corrupt_tail / corrupt_rel
};

// Worker body. Thread `id` fills positions [lef, rig) of the batch with positive
// triples; the k-th negative of the positive at position b is written to b + k * batchSize.
void* getBatch(void* con) {
    Parameter *para = (Parameter *)(con);
    INT id = para -> id;
    INT *batch_h = para -> batch_h;
    INT *batch_t = para -> batch_t;
    INT *batch_r = para -> batch_r;
    REAL *batch_y = para -> batch_y;
    INT batchSize = para -> batchSize;
    INT negRate = para -> negRate;
    INT negRelRate = para -> negRelRate;
    bool p = para -> p;
    bool val_loss = para -> val_loss;
    INT mode = para -> mode;
    bool filter_flag = para -> filter_flag;
    INT lef, rig;
    if (batchSize % workThreads == 0) {
        // The batch splits evenly across the threads.
        lef = id * (batchSize / workThreads);
        rig = (id + 1) * (batchSize / workThreads);
    } else {
        // Round the slice size up and clip the last slice to batchSize.
        lef = id * (batchSize / workThreads + 1);
        rig = (id + 1) * (batchSize / workThreads + 1);
        if (rig > batchSize) rig = batchSize;
    }
    // Threshold out of 1000 for replacing the tail (500 = 50%). This local shadows
    // the global REAL *prob from Reader.h.
    REAL prob = 500;
    if (val_loss == false) {
        for (INT batch = lef; batch < rig; batch++) {
            // Pick a random training triple as the positive example.
            INT i = rand_max(id, trainTotal);
            batch_h[batch] = trainList[i].h;
            batch_t[batch] = trainList[i].t;
            batch_r[batch] = trainList[i].r;
            batch_y[batch] = 1;
            INT last = batchSize;
            for (INT times = 0; times < negRate; times ++) {
                if (mode == 0){
                    // With bernFlag, replace the tail with probability
                    // right_mean[r] / (right_mean[r] + left_mean[r]).
                    if (bernFlag)
                        prob = 1000 * right_mean[trainList[i].r] / (right_mean[trainList[i].r] + left_mean[trainList[i].r]);
                    if (randd(id) % 1000 < prob) {
                        // corrupt_head(h, r) returns a replacement tail for (h, r, ?).
                        batch_h[batch + last] = trainList[i].h;
                        batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r, filter_flag);
                        batch_r[batch + last] = trainList[i].r;
                    } else {
                        // corrupt_tail(t, r) returns a replacement head for (?, r, t).
                        batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r, filter_flag);
                        batch_t[batch + last] = trainList[i].t;
                        batch_r[batch + last] = trainList[i].r;
                    }
                    batch_y[batch + last] = -1;
                    last += batchSize;
                } else {
                    if(mode == -1){
                        batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r, filter_flag);
                        batch_t[batch + last] = trainList[i].t;
                        batch_r[batch + last] = trainList[i].r;
                    } else {
                        batch_h[batch + last] = trainList[i].h;
                        batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r, filter_flag);
                        batch_r[batch + last] = trainList[i].r;
                    }
                    batch_y[batch + last] = -1;
                    last += batchSize;
                }
            }
            for (INT times = 0; times < negRelRate; times++) {
                batch_h[batch + last] = trainList[i].h;
                batch_t[batch + last] = trainList[i].t;
                // Replacement relation for (h, ?, t).
                batch_r[batch + last] = corrupt_rel(id, trainList[i].h, trainList[i].t, trainList[i].r, p, filter_flag);
                batch_y[batch + last] = -1;
                last += batchSize;
            }
        }
    }
    else
    {
        // Validation loss: copy validation triples (positives only).
        for (INT batch = lef; batch < rig; batch++)
        {
            batch_h[batch] = validList[batch].h;
            batch_t[batch] = validList[batch].t;
            batch_r[batch] = validList[batch].r;
            batch_y[batch] = 1;
        }
    }
    pthread_exit(NULL);
}

// Fill one batch using workThreads parallel getBatch workers, then wait for all of them.
// The buffers must hold batchSize * (1 + negRate + negRelRate) entries.
extern "C"
void sampling(
        INT *batch_h,
        INT *batch_t,
        INT *batch_r,
        REAL *batch_y,
        INT batchSize,
        INT negRate = 1,
        INT negRelRate = 0,
        INT mode = 0,
        bool filter_flag = true,
        bool p = false,
        bool val_loss = false
) {
    pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t));
    Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter));
    for (INT threads = 0; threads < workThreads; threads++) {
        para[threads].id = threads;
        para[threads].batch_h = batch_h;
        para[threads].batch_t = batch_t;
        para[threads].batch_r = batch_r;
        para[threads].batch_y = batch_y;
        para[threads].batchSize = batchSize;
        para[threads].negRate = negRate;
        para[threads].negRelRate = negRelRate;
        para[threads].p = p;
        para[threads].val_loss = val_loss;
        para[threads].mode = mode;
        para[threads].filter_flag = filter_flag;
        // Start a worker at getBatch with its own Parameter entry as the argument.
        pthread_create(&pt[threads], NULL, getBatch, (void*)(para+threads));
    }
    // pthread_join blocks until the given thread terminates; the batch is complete
    // only after every worker has been joined.
    for (INT threads = 0; threads < workThreads; threads++)
        pthread_join(pt[threads], NULL);

    free(pt);
    free(para);
}

int main() {
    importTrainFiles();
    return 0;
}
#ifndef CORRUPT_H #define CORRUPT_H #include "Random.h" #include "Triple.h" #include "Reader.h" INT corrupt_head(INT id, INT h, INT r, bool filter_flag = true) { INT lef, rig, mid, ll, rr; if (not filter_flag) { //如果filter_flag为false INT tmp = rand_max(id, entityTotal - 1); //获取[0,entityTotal-1)的随机数 if (tmp < h) //如果随机数比h的ID小,直接返回随机数 return tmp; else //反之,返回随机数+1 return tmp + 1; } //二分搜索,查找r在trainHead中的下标,即mid lef = lefHead[h] - 1; rig = rigHead[h]; while (lef + 1 < rig) { mid = (lef + rig) >> 1; //将结果右移一位,等价于(lef+rig)/2 if (trainHead[mid].r >= r) rig = mid; else //如果mid对应的关系大于或等于r,则rig=mid lef = mid; //反之,lef=mid } ll = rig; //二分搜索,同上,只是查找的范围移动了一位 lef = lefHead[h]; rig = rigHead[h] + 1; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainHead[mid].r <= r) lef = mid; else rig = mid; } rr = lef; INT tmp = rand_max(id, entityTotal - (rr - ll + 1)); if (tmp < trainHead[ll].t) return tmp; //如果tmp小于ll在trainHead中的头实体,则返回tmp if (tmp > trainHead[rr].t - rr + ll - 1) return tmp + rr - ll + 1; //如果大于 lef = ll, rig = rr + 1; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainHead[mid].t - mid + ll - 1 < tmp) lef = mid; else rig = mid; } return tmp + lef - ll + 1; } INT corrupt_tail(INT id, INT t, INT r, bool filter_flag = true) { INT lef, rig, mid, ll, rr; if (not filter_flag) { INT tmp = rand_max(id, entityTotal - 1); if (tmp < t) return tmp; else return tmp + 1; } lef = lefTail[t] - 1; rig = rigTail[t]; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainTail[mid].r >= r) rig = mid; else lef = mid; } ll = rig; lef = lefTail[t]; rig = rigTail[t] + 1; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainTail[mid].r <= r) lef = mid; else rig = mid; } rr = lef; INT tmp = rand_max(id, entityTotal - (rr - ll + 1)); if (tmp < trainTail[ll].h) return tmp; if (tmp > trainTail[rr].h - rr + ll - 1) return tmp + rr - ll + 1; lef = ll, rig = rr + 1; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainTail[mid].h - mid + ll - 1 < tmp) lef = mid; else rig = mid; } 
return tmp + lef - ll + 1; } INT corrupt_rel(INT id, INT h, INT t, INT r, bool p = false, bool filter_flag = true) { INT lef, rig, mid, ll, rr; if (not filter_flag) { INT tmp = rand_max(id, relationTotal - 1); if (tmp < r) return tmp; else return tmp + 1; } lef = lefRel[h] - 1; rig = rigRel[h]; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainRel[mid].t >= t) rig = mid; else lef = mid; } ll = rig; lef = lefRel[h]; rig = rigRel[h] + 1; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainRel[mid].t <= t) lef = mid; else rig = mid; } rr = lef; INT tmp; if(p == false) { tmp = rand_max(id, relationTotal - (rr - ll + 1)); } else { INT start = r * (relationTotal - 1); REAL sum = 1; bool *record = (bool *)calloc(relationTotal - 1, sizeof(bool)); for (INT i = ll; i <= rr; ++i){ if (trainRel[i].r > r){ sum -= prob[start + trainRel[i].r-1]; record[trainRel[i].r-1] = true; } else if (trainRel[i].r < r){ sum -= prob[start + trainRel[i].r]; record[trainRel[i].r] = true; } } REAL *prob_tmp = (REAL *)calloc(relationTotal-(rr-ll+1), sizeof(REAL)); INT cnt = 0; REAL rec = 0; for (INT i = start; i < start + relationTotal - 1; ++i) { if (record[i-start]) continue; rec += prob[i] / sum; prob_tmp[cnt++] = rec; } REAL m = rand_max(id, 10000) / 10000.0; lef = 0; rig = cnt - 1; while (lef < rig) { mid = (lef + rig) >> 1; if (prob_tmp[mid] < m) lef = mid + 1; else rig = mid; } tmp = rig; free(prob_tmp); free(record); } if (tmp < trainRel[ll].r) return tmp; if (tmp > trainRel[rr].r - rr + ll - 1) return tmp + rr - ll + 1; lef = ll, rig = rr + 1; while (lef + 1 < rig) { mid = (lef + rig) >> 1; if (trainRel[mid].r - mid + ll - 1 < tmp) lef = mid; else rig = mid; } return tmp + lef - ll + 1; } bool _find(INT h, INT t, INT r) { INT lef = 0; INT rig = tripleTotal - 1; INT mid; while (lef + 1 < rig) { INT mid = (lef + rig) >> 1; if ((tripleList[mid]. h < h) || (tripleList[mid]. h == h && tripleList[mid]. r < r) || (tripleList[mid]. h == h && tripleList[mid]. r == r && tripleList[mid]. 
t < t)) lef = mid; else rig = mid; } if (tripleList[lef].h == h && tripleList[lef].r == r && tripleList[lef].t == t) return true; if (tripleList[rig].h == h && tripleList[rig].r == r && tripleList[rig].t == t) return true; return false; } INT corrupt(INT h, INT r){ INT ll = tail_lef[r]; INT rr = tail_rig[r]; INT loop = 0; INT t; while(true) { t = tail_type[rand(ll, rr)]; if (not _find(h, t, r)) { return t; } else { loop ++; if (loop >= 1000) { return corrupt_head(0, h, r); } } } } #endif
#ifndef RANDOM_H #define RANDOM_H #include "Setting.h" #include <cstdlib> // the random seeds for all threads. unsigned long long *next_random; // reset the random seeds for all threads extern "C" void randReset() { //calloc: 在内存的动态存储区中分配workThreads个长度为size的连续空间,函数返回一个指向分配起始地址的指针;如果分配不成功,返回NULL。 next_random = (unsigned long long *)calloc(workThreads, sizeof(unsigned long long)); for (INT i = 0; i < workThreads; i++) next_random[i] = rand(); } // get a random interger for the id-th thread with the corresponding random seed. unsigned long long randd(INT id) { next_random[id] = next_random[id] * (unsigned long long)(25214903917) + 11; return next_random[id]; } // get a random interger from the range [0,x) for the id-th thread. INT rand_max(INT id, INT x) { INT res = randd(id) % x; while (res < 0) res += x; return res; } // get a random interger from the range [a,b) for the id-th thread. INT rand(INT a, INT b){ return (rand() % (b-a))+ a; } #endif
#ifndef READER_H #define READER_H #include "Setting.h" #include "Triple.h" #include <cstdlib> #include <algorithm> #include <iostream> #include <cmath> INT *freqRel, *freqEnt; INT *lefHead, *rigHead; INT *lefTail, *rigTail; INT *lefRel, *rigRel; REAL *left_mean, *right_mean; REAL *prob; Triple *trainList; Triple *trainHead; Triple *trainTail; Triple *trainRel; INT *testLef, *testRig; INT *validLef, *validRig; extern "C" void importProb(REAL temp){ if (prob != NULL) free(prob); FILE *fin; fin = fopen((inPath + "kl_prob.txt").c_str(), "r"); printf("Current temperature:%f\n", temp); prob = (REAL *)calloc(relationTotal * (relationTotal - 1), sizeof(REAL)); INT tmp; for (INT i = 0; i < relationTotal * (relationTotal - 1); ++i){ tmp = fscanf(fin, "%f", &prob[i]); } REAL sum = 0.0; for (INT i = 0; i < relationTotal; ++i) { for (INT j = 0; j < relationTotal-1; ++j){ REAL tmp = exp(-prob[i * (relationTotal - 1) + j] / temp); sum += tmp; prob[i * (relationTotal - 1) + j] = tmp; } for (INT j = 0; j < relationTotal-1; ++j){ prob[i*(relationTotal-1)+j] /= sum; } sum = 0; } fclose(fin); } extern "C" void importTrainFiles() { printf("The toolkit is importing datasets.\n"); FILE *fin; int tmp; //读取关系数据集 if (rel_file == "") fin = fopen((inPath + "relation2id.txt").c_str(), "r"); else fin = fopen(rel_file.c_str(), "r"); //打开文件输入流 tmp = fscanf(fin, "%ld", &relationTotal); //读取第一行,作为关系的总数,赋值到relationTotal printf("The total of relations is %ld.\n", relationTotal); fclose(fin); //关闭文件输入流 //读取实体数据集 if (ent_file == "") fin = fopen((inPath + "entity2id.txt").c_str(), "r"); else fin = fopen(ent_file.c_str(), "r"); tmp = fscanf(fin, "%ld", &entityTotal); printf("The total of entities is %ld.\n", entityTotal); fclose(fin); //读取训练数据集,三元组,头实体ID,尾实体ID,关系ID if (train_file == "") fin = fopen((inPath + "train2id.txt").c_str(), "r"); else fin = fopen(train_file.c_str(), "r"); tmp = fscanf(fin, "%ld", &trainTotal); //根据训练集的大小,内存分配对应的空间大小给trainList、trainHead、trainTail、trainRel trainList = (Triple 
*)calloc(trainTotal, sizeof(Triple)); trainHead = (Triple *)calloc(trainTotal, sizeof(Triple)); trainTail = (Triple *)calloc(trainTotal, sizeof(Triple)); trainRel = (Triple *)calloc(trainTotal, sizeof(Triple)); //根据关系总数,分配内存给freqRel,freqRel表示关系的频率 freqRel = (INT *)calloc(relationTotal, sizeof(INT)); //根据实体总数,分配内存给freqEnt,freqEnt表示实体的频率 freqEnt = (INT *)calloc(entityTotal, sizeof(INT)); for (INT i = 0; i < trainTotal; i++) { //将train2id.txt中的三列数据,分别保存到trainList中 tmp = fscanf(fin, "%ld", &trainList[i].h); tmp = fscanf(fin, "%ld", &trainList[i].t); tmp = fscanf(fin, "%ld", &trainList[i].r); } fclose(fin); //按照头实体ID的大小,对trainList进行排序,若头实体ID相等,则判断关系ID;若头实体、关系都相等,则判断尾实体ID;并以升序的方式排列 std::sort(trainList, trainList + trainTotal, Triple::cmp_head); tmp = trainTotal; trainTotal = 1; trainHead[0] = trainTail[0] = trainRel[0] = trainList[0]; freqEnt[trainList[0].t] += 1; //以trainList[0]的尾实体作为数组freqEnt的下标,对应的值+1 freqEnt[trainList[0].h] += 1; //以trainList[0]的头实体作为数组freqEnt的下标,对应的值+1 freqRel[trainList[0].r] += 1; //以trainList[0]的关系作为数组freqEnt的下标,对应的值+1 //从i=1到train2id.txt中总的训练行数,遍历trainList for (INT i = 1; i < tmp; i++) //如果第i的一个的头实体不与i-1的头实体相等,或者i的关系不与i-1对应的关系相等,或者i的尾实体不与i-1的尾实体相等 //即,第i的一个训练三元组不与第i-1的训练三元组相同 if (trainList[i].h != trainList[i - 1].h || trainList[i].r != trainList[i - 1].r || trainList[i].t != trainList[i - 1].t) { //排除相邻且相同的三元组后,剩下不重复的训练三元组 trainHead[trainTotal] = trainTail[trainTotal] = trainRel[trainTotal] = trainList[trainTotal] = trainList[i]; trainTotal++; freqEnt[trainList[i].t]++; //以trainList[i]的尾实体作为数组freqEnt的下标,对应的值+1 freqEnt[trainList[i].h]++; //以trainList[i]的头实体作为数组freqEnt的下标,对应的值+1 freqRel[trainList[i].r]++; //以trainList[i]的关系作为数组freqEnt的下标,对应的值+1 } //按照头实体的大小,对trainHead进行排序,以升序的方式,若头实体ID相等,则判断关系ID;若头实体、关系都相等,则判断尾实体ID; std::sort(trainHead, trainHead + trainTotal, Triple::cmp_head); //按照尾实体的大小,对trainTail进行排序,以升序的方式,若尾实体ID相等,则判断关系ID;若尾实体、关系都相等,则判断尾实体ID; std::sort(trainTail, trainTail + trainTotal, Triple::cmp_tail); 
//按照头实体的大小,对trainRel进行排序,以升序的方式,若头实体ID相等,则判断尾实体;若头实体、尾实体都相等,则判断关系ID; std::sort(trainRel, trainRel + trainTotal, Triple::cmp_rel); printf("The total of train triples is %ld.\n", trainTotal); //以实体总数,分配内存空间给lefHead、lefHead、lefTail、rigTail、lefRel、rigRel lefHead = (INT *)calloc(entityTotal, sizeof(INT)); lefHead = (INT *)calloc(entityTotal, sizeof(INT)); lefTail = (INT *)calloc(entityTotal, sizeof(INT)); rigTail = (INT *)calloc(entityTotal, sizeof(INT)); lefRel = (INT *)calloc(entityTotal, sizeof(INT)); rigRel = (INT *)calloc(entityTotal, sizeof(INT)); //对数组rigHead、rigTail、rigRel初始化为-1 memset(rigHead, -1, sizeof(INT)*entityTotal); memset(rigTail, -1, sizeof(INT)*entityTotal); memset(rigRel, -1, sizeof(INT)*entityTotal); //从i=1,到trainTotal //ritTail保存的是尾实体ID较小的对应的trainT下标 //lefTail保存的是尾实体ID较大的对应的trainT下标 //rigHead、lefHead、rigRel、lefRel同理 for (INT i = 1; i < trainTotal; i++) { //如果trainTail,第i中的尾实体与i-1中的尾实体不一样 //即,如果相邻两个训练尾实体不相同,则以前者尾实体为rigTail的下标,将i-1替换对应的-1 // 将后者尾实体为lefTail的下标,将i替换对应的-1 if (trainTail[i].t != trainTail[i - 1].t) { rigTail[trainTail[i - 1].t] = i - 1; //将i-1赋值给以trainTail[i-1]的尾实体为下标,对应的rigTail值-1 lefTail[trainTail[i].t] = i; //将i赋值给以trainTail[i]的尾实体为下标,对应的lefTail值-1 } if (trainHead[i].h != trainHead[i - 1].h) { rigHead[trainHead[i - 1].h] = i - 1; lefHead[trainHead[i].h] = i; } if (trainRel[i].h != trainRel[i - 1].h) { rigRel[trainRel[i - 1].h] = i - 1; lefRel[trainRel[i].h] = i; } } //将以0作为下标的值赋值为0,以及以训练集的最后一位作为下标,赋值为trainTotal-1 lefHead[trainHead[0].h] = 0; rigHead[trainHead[trainTotal - 1].h] = trainTotal - 1; lefTail[trainTail[0].t] = 0; rigTail[trainTail[trainTotal - 1].t] = trainTotal - 1; lefRel[trainRel[0].h] = 0; rigRel[trainRel[trainTotal - 1].h] = trainTotal - 1; //为left_mean、right_mean分配实数型的内存,元素个数为relationTotal,大小为REAL left_mean = (REAL *)calloc(relationTotal,sizeof(REAL)); right_mean = (REAL *)calloc(relationTotal,sizeof(REAL)); for (INT i = 0; i < entityTotal; i++) { for (INT j = lefHead[i] + 1; j <= rigHead[i]; j++) if (trainHead[j].r != 
trainHead[j - 1].r) left_mean[trainHead[j].r] += 1.0; //相邻训练头实体对应的关系不等情况下,对头实体的出边+1 if (lefHead[i] <= rigHead[i]) left_mean[trainHead[lefHead[i]].r] += 1.0; //如果左实体的大小小于等于右实体的大小,则以左实体对应的出边+1 for (INT j = lefTail[i] + 1; j <= rigTail[i]; j++) if (trainTail[j].r != trainTail[j - 1].r) right_mean[trainTail[j].r] += 1.0; if (lefTail[i] <= rigTail[i]) right_mean[trainTail[lefTail[i]].r] += 1.0; } for (INT i = 0; i < relationTotal; i++) { left_mean[i] = freqRel[i] / left_mean[i]; //实体的个数除以对应实体的出边 right_mean[i] = freqRel[i] / right_mean[i]; //实体的个数除以对应实体的入边 } } Triple *testList; Triple *validList; Triple *tripleList; extern "C" void importTestFiles() { FILE *fin; INT tmp; if (rel_file == "") fin = fopen((inPath + "relation2id.txt").c_str(), "r"); else fin = fopen(rel_file.c_str(), "r"); tmp = fscanf(fin, "%ld", &relationTotal); fclose(fin); if (ent_file == "") fin = fopen((inPath + "entity2id.txt").c_str(), "r"); else fin = fopen(ent_file.c_str(), "r"); tmp = fscanf(fin, "%ld", &entityTotal); fclose(fin); FILE* f_kb1, * f_kb2, * f_kb3; if (train_file == "") f_kb2 = fopen((inPath + "train2id.txt").c_str(), "r"); else f_kb2 = fopen(train_file.c_str(), "r"); if (test_file == "") f_kb1 = fopen((inPath + "test2id.txt").c_str(), "r"); else f_kb1 = fopen(test_file.c_str(), "r"); if (valid_file == "") f_kb3 = fopen((inPath + "valid2id.txt").c_str(), "r"); else f_kb3 = fopen(valid_file.c_str(), "r"); tmp = fscanf(f_kb1, "%ld", &testTotal); tmp = fscanf(f_kb2, "%ld", &trainTotal); tmp = fscanf(f_kb3, "%ld", &validTotal); tripleTotal = testTotal + trainTotal + validTotal; testList = (Triple *)calloc(testTotal, sizeof(Triple)); validList = (Triple *)calloc(validTotal, sizeof(Triple)); tripleList = (Triple *)calloc(tripleTotal, sizeof(Triple)); for (INT i = 0; i < testTotal; i++) { tmp = fscanf(f_kb1, "%ld", &testList[i].h); tmp = fscanf(f_kb1, "%ld", &testList[i].t); tmp = fscanf(f_kb1, "%ld", &testList[i].r); tripleList[i] = testList[i]; } for (INT i = 0; i < trainTotal; i++) { tmp 
= fscanf(f_kb2, "%ld", &tripleList[i + testTotal].h); tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].t); tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].r); } for (INT i = 0; i < validTotal; i++) { tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].h); tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].t); tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].r); validList[i] = tripleList[i + testTotal + trainTotal]; } fclose(f_kb1); fclose(f_kb2); fclose(f_kb3); std::sort(tripleList, tripleList + tripleTotal, Triple::cmp_head); std::sort(testList, testList + testTotal, Triple::cmp_rel2); std::sort(validList, validList + validTotal, Triple::cmp_rel2); printf("The total of test triples is %ld.\n", testTotal); printf("The total of valid triples is %ld.\n", validTotal); testLef = (INT *)calloc(relationTotal, sizeof(INT)); testRig = (INT *)calloc(relationTotal, sizeof(INT)); memset(testLef, -1, sizeof(INT) * relationTotal); memset(testRig, -1, sizeof(INT) * relationTotal); for (INT i = 1; i < testTotal; i++) { if (testList[i].r != testList[i-1].r) { testRig[testList[i-1].r] = i - 1; testLef[testList[i].r] = i; } } testLef[testList[0].r] = 0; testRig[testList[testTotal - 1].r] = testTotal - 1; validLef = (INT *)calloc(relationTotal, sizeof(INT)); validRig = (INT *)calloc(relationTotal, sizeof(INT)); memset(validLef, -1, sizeof(INT)*relationTotal); memset(validRig, -1, sizeof(INT)*relationTotal); for (INT i = 1; i < validTotal; i++) { if (validList[i].r != validList[i-1].r) { validRig[validList[i-1].r] = i - 1; validLef[validList[i].r] = i; } } validLef[validList[0].r] = 0; validRig[validList[validTotal - 1].r] = validTotal - 1; } INT* head_lef; INT* head_rig; INT* tail_lef; INT* tail_rig; INT* head_type; INT* tail_type; extern "C" void importTypeFiles() { head_lef = (INT *)calloc(relationTotal, sizeof(INT)); head_rig = (INT *)calloc(relationTotal, sizeof(INT)); tail_lef = (INT *)calloc(relationTotal, 
sizeof(INT)); tail_rig = (INT *)calloc(relationTotal, sizeof(INT)); INT total_lef = 0; INT total_rig = 0; FILE* f_type = fopen((inPath + "type_constrain.txt").c_str(),"r"); INT tmp; tmp = fscanf(f_type, "%ld", &tmp); for (INT i = 0; i < relationTotal; i++) { INT rel, tot; tmp = fscanf(f_type, "%ld %ld", &rel, &tot); for (INT j = 0; j < tot; j++) { tmp = fscanf(f_type, "%ld", &tmp); total_lef++; } tmp = fscanf(f_type, "%ld%ld", &rel, &tot); for (INT j = 0; j < tot; j++) { tmp = fscanf(f_type, "%ld", &tmp); total_rig++; } } fclose(f_type); head_type = (INT *)calloc(total_lef, sizeof(INT)); tail_type = (INT *)calloc(total_rig, sizeof(INT)); total_lef = 0; total_rig = 0; f_type = fopen((inPath + "type_constrain.txt").c_str(),"r"); tmp = fscanf(f_type, "%ld", &tmp); for (INT i = 0; i < relationTotal; i++) { INT rel, tot; tmp = fscanf(f_type, "%ld%ld", &rel, &tot); head_lef[rel] = total_lef; for (INT j = 0; j < tot; j++) { tmp = fscanf(f_type, "%ld", &head_type[total_lef]); total_lef++; } head_rig[rel] = total_lef; std::sort(head_type + head_lef[rel], head_type + head_rig[rel]); tmp = fscanf(f_type, "%ld%ld", &rel, &tot); tail_lef[rel] = total_rig; for (INT j = 0; j < tot; j++) { tmp = fscanf(f_type, "%ld", &tail_type[total_rig]); total_rig++; } tail_rig[rel] = total_rig; std::sort(tail_type + tail_lef[rel], tail_type + tail_rig[rel]); } fclose(f_type); } #endif
#ifndef SETTING_H
#define SETTING_H
#define INT long
#define REAL float
#include <cstring>
#include <cstdio>
#include <string>

// Dataset locations. File names are appended directly to inPath (e.g.
// inPath + "train2id.txt"), so it should end with a path separator. The *_file
// strings, when non-empty, replace the default file names under inPath.
std::string inPath = "../data/FB15K/";
std::string outPath = "../data/FB15K/";
std::string ent_file = "";
std::string rel_file = "";
std::string train_file = "";
std::string valid_file = "";
std::string test_file = "";

// extern "C" gives these functions C linkage (no C++ name mangling), so they can be
// looked up by their plain names from Python through ctypes. It does not change how
// the function bodies are compiled; they are still C++.
// Each setter copies the NUL-terminated path into the corresponding std::string
// (direct assignment instead of the previous quadratic char-by-char concatenation).
extern "C"
void setInPath(char *path) {
    inPath = path;
    printf("Input Files Path : %s\n", inPath.c_str());
}

extern "C"
void setOutPath(char *path) {
    outPath = path;
    printf("Output Files Path : %s\n", outPath.c_str());
}

extern "C"
void setTrainPath(char *path) {
    train_file = path;
    printf("Training Files Path : %s\n", train_file.c_str());
}

extern "C"
void setValidPath(char *path) {
    valid_file = path;
    printf("Valid Files Path : %s\n", valid_file.c_str());
}

extern "C"
void setTestPath(char *path) {
    test_file = path;
    printf("Test Files Path : %s\n", test_file.c_str());
}

extern "C"
void setEntPath(char *path) {
    ent_file = path;
    printf("Entity Files Path : %s\n", ent_file.c_str());
}

extern "C"
void setRelPath(char *path) {
    rel_file = path;
    printf("Relation Files Path : %s\n", rel_file.c_str());
}

/*
============================================================
*/

// Number of sampling worker threads.
INT workThreads = 1;

extern "C"
void setWorkThreads(INT threads) {
    workThreads = threads;
}

extern "C"
INT getWorkThreads() {
    return workThreads;
}

/*
============================================================
*/

INT relationTotal = 0;
INT entityTotal = 0;
INT tripleTotal = 0;
INT testTotal = 0;
INT trainTotal = 0;
INT validTotal = 0;

extern "C"
INT getEntityTotal() {
    return entityTotal;
}

extern "C"
INT getRelationTotal() {
    return relationTotal;
}

extern "C"
INT getTripleTotal() {
    return tripleTotal;
}

extern "C"
INT getTrainTotal() {
    return trainTotal;
}

extern "C"
INT getTestTotal() {
    return testTotal;
}

extern "C"
INT getValidTotal() {
    return validTotal;
}

/*
============================================================
*/

// Non-zero: choose head vs. tail corruption with relation-dependent probability.
INT bernFlag = 0;

extern "C"
void setBern(INT con) {
    bernFlag = con;
}

#endif
#ifndef TEST_H #define TEST_H #include "Setting.h" #include "Reader.h" #include "Corrupt.h" /*===================================================================================== link prediction ======================================================================================*/ INT lastHead = 0; INT lastTail = 0; INT lastRel = 0; REAL l1_filter_tot = 0, l1_tot = 0, r1_tot = 0, r1_filter_tot = 0, l_tot = 0, r_tot = 0, l_filter_rank = 0, l_rank = 0, l_filter_reci_rank = 0, l_reci_rank = 0; REAL l3_filter_tot = 0, l3_tot = 0, r3_tot = 0, r3_filter_tot = 0, l_filter_tot = 0, r_filter_tot = 0, r_filter_rank = 0, r_rank = 0, r_filter_reci_rank = 0, r_reci_rank = 0; REAL rel3_tot = 0, rel3_filter_tot = 0, rel_filter_tot = 0, rel_filter_rank = 0, rel_rank = 0, rel_filter_reci_rank = 0, rel_reci_rank = 0, rel_tot = 0, rel1_tot = 0, rel1_filter_tot = 0; REAL l1_filter_tot_constrain = 0, l1_tot_constrain = 0, r1_tot_constrain = 0, r1_filter_tot_constrain = 0, l_tot_constrain = 0, r_tot_constrain = 0, l_filter_rank_constrain = 0, l_rank_constrain = 0, l_filter_reci_rank_constrain = 0, l_reci_rank_constrain = 0; REAL l3_filter_tot_constrain = 0, l3_tot_constrain = 0, r3_tot_constrain = 0, r3_filter_tot_constrain = 0, l_filter_tot_constrain = 0, r_filter_tot_constrain = 0, r_filter_rank_constrain = 0, r_rank_constrain = 0, r_filter_reci_rank_constrain = 0, r_reci_rank_constrain = 0; REAL hit1, hit3, hit10, mr, mrr; REAL hit1TC, hit3TC, hit10TC, mrTC, mrrTC; extern "C" void initTest() { lastHead = 0; lastTail = 0; lastRel = 0; l1_filter_tot = 0, l1_tot = 0, r1_tot = 0, r1_filter_tot = 0, l_tot = 0, r_tot = 0, l_filter_rank = 0, l_rank = 0, l_filter_reci_rank = 0, l_reci_rank = 0; l3_filter_tot = 0, l3_tot = 0, r3_tot = 0, r3_filter_tot = 0, l_filter_tot = 0, r_filter_tot = 0, r_filter_rank = 0, r_rank = 0, r_filter_reci_rank = 0, r_reci_rank = 0; REAL rel3_tot = 0, rel3_filter_tot = 0, rel_filter_tot = 0, rel_filter_rank = 0, rel_rank = 0, rel_filter_reci_rank = 0, 
rel_reci_rank = 0, rel_tot = 0, rel1_tot = 0, rel1_filter_tot = 0; l1_filter_tot_constrain = 0, l1_tot_constrain = 0, r1_tot_constrain = 0, r1_filter_tot_constrain = 0, l_tot_constrain = 0, r_tot_constrain = 0, l_filter_rank_constrain = 0, l_rank_constrain = 0, l_filter_reci_rank_constrain = 0, l_reci_rank_constrain = 0; l3_filter_tot_constrain = 0, l3_tot_constrain = 0, r3_tot_constrain = 0, r3_filter_tot_constrain = 0, l_filter_tot_constrain = 0, r_filter_tot_constrain = 0, r_filter_rank_constrain = 0, r_rank_constrain = 0, r_filter_reci_rank_constrain = 0, r_reci_rank_constrain = 0; } extern "C" void getHeadBatch(INT *ph, INT *pt, INT *pr) { for (INT i = 0; i < entityTotal; i++) { ph[i] = i; pt[i] = testList[lastHead].t; pr[i] = testList[lastHead].r; } lastHead++; } extern "C" void getTailBatch(INT *ph, INT *pt, INT *pr) { for (INT i = 0; i < entityTotal; i++) { ph[i] = testList[lastTail].h; pt[i] = i; pr[i] = testList[lastTail].r; } lastTail++; } extern "C" void getRelBatch(INT *ph, INT *pt, INT *pr) { for (INT i = 0; i < relationTotal; i++) { ph[i] = testList[lastRel].h; pt[i] = testList[lastRel].t; pr[i] = i; } } extern "C" void testHead(REAL *con, INT lastHead, bool type_constrain = false) { INT h = testList[lastHead].h; INT t = testList[lastHead].t; INT r = testList[lastHead].r; INT lef, rig; if (type_constrain) { lef = head_lef[r]; rig = head_rig[r]; } REAL minimal = con[h]; INT l_s = 0; INT l_filter_s = 0; INT l_s_constrain = 0; INT l_filter_s_constrain = 0; for (INT j = 0; j < entityTotal; j++) { if (j != h) { REAL value = con[j]; if (value < minimal) { l_s += 1; if (not _find(j, t, r)) l_filter_s += 1; } if (type_constrain) { while (lef < rig && head_type[lef] < j) lef ++; if (lef < rig && j == head_type[lef]) { if (value < minimal) { l_s_constrain += 1; if (not _find(j, t, r)) { l_filter_s_constrain += 1; } } } } } } if (l_filter_s < 10) l_filter_tot += 1; if (l_s < 10) l_tot += 1; if (l_filter_s < 3) l3_filter_tot += 1; if (l_s < 3) l3_tot += 1; if 
(l_filter_s < 1) l1_filter_tot += 1; if (l_s < 1) l1_tot += 1; l_filter_rank += (l_filter_s+1); l_rank += (1 + l_s); l_filter_reci_rank += 1.0/(l_filter_s+1); l_reci_rank += 1.0/(l_s+1); if (type_constrain) { if (l_filter_s_constrain < 10) l_filter_tot_constrain += 1; if (l_s_constrain < 10) l_tot_constrain += 1; if (l_filter_s_constrain < 3) l3_filter_tot_constrain += 1; if (l_s_constrain < 3) l3_tot_constrain += 1; if (l_filter_s_constrain < 1) l1_filter_tot_constrain += 1; if (l_s_constrain < 1) l1_tot_constrain += 1; l_filter_rank_constrain += (l_filter_s_constrain+1); l_rank_constrain += (1+l_s_constrain); l_filter_reci_rank_constrain += 1.0/(l_filter_s_constrain+1); l_reci_rank_constrain += 1.0/(l_s_constrain+1); } } extern "C" void testTail(REAL *con, INT lastTail, bool type_constrain = false) { INT h = testList[lastTail].h; INT t = testList[lastTail].t; INT r = testList[lastTail].r; INT lef, rig; if (type_constrain) { lef = tail_lef[r]; rig = tail_rig[r]; } REAL minimal = con[t]; INT r_s = 0; INT r_filter_s = 0; INT r_s_constrain = 0; INT r_filter_s_constrain = 0; for (INT j = 0; j < entityTotal; j++) { if (j != t) { REAL value = con[j]; if (value < minimal) { r_s += 1; if (not _find(h, j, r)) r_filter_s += 1; } if (type_constrain) { while (lef < rig && tail_type[lef] < j) lef ++; if (lef < rig && j == tail_type[lef]) { if (value < minimal) { r_s_constrain += 1; if (not _find(h, j ,r)) { r_filter_s_constrain += 1; } } } } } } if (r_filter_s < 10) r_filter_tot += 1; if (r_s < 10) r_tot += 1; if (r_filter_s < 3) r3_filter_tot += 1; if (r_s < 3) r3_tot += 1; if (r_filter_s < 1) r1_filter_tot += 1; if (r_s < 1) r1_tot += 1; r_filter_rank += (1+r_filter_s); r_rank += (1+r_s); r_filter_reci_rank += 1.0/(1+r_filter_s); r_reci_rank += 1.0/(1+r_s); if (type_constrain) { if (r_filter_s_constrain < 10) r_filter_tot_constrain += 1; if (r_s_constrain < 10) r_tot_constrain += 1; if (r_filter_s_constrain < 3) r3_filter_tot_constrain += 1; if (r_s_constrain < 3) 
r3_tot_constrain += 1; if (r_filter_s_constrain < 1) r1_filter_tot_constrain += 1; if (r_s_constrain < 1) r1_tot_constrain += 1; r_filter_rank_constrain += (1+r_filter_s_constrain); r_rank_constrain += (1+r_s_constrain); r_filter_reci_rank_constrain += 1.0/(1+r_filter_s_constrain); r_reci_rank_constrain += 1.0/(1+r_s_constrain); } } extern "C" void testRel(REAL *con) { INT h = testList[lastRel].h; INT t = testList[lastRel].t; INT r = testList[lastRel].r; REAL minimal = con[r]; INT rel_s = 0; INT rel_filter_s = 0; for (INT j = 0; j < relationTotal; j++) { if (j != r) { REAL value = con[j]; if (value < minimal) { rel_s += 1; if (not _find(h, t, j)) rel_filter_s += 1; } } } if (rel_filter_s < 10) rel_filter_tot += 1; if (rel_s < 10) rel_tot += 1; if (rel_filter_s < 3) rel3_filter_tot += 1; if (rel_s < 3) rel3_tot += 1; if (rel_filter_s < 1) rel1_filter_tot += 1; if (rel_s < 1) rel1_tot += 1; rel_filter_rank += (rel_filter_s+1); rel_rank += (1+rel_s); rel_filter_reci_rank += 1.0/(rel_filter_s+1); rel_reci_rank += 1.0/(rel_s+1); lastRel++; } extern "C" void test_link_prediction(bool type_constrain = false) { l_rank /= testTotal; r_rank /= testTotal; l_reci_rank /= testTotal; r_reci_rank /= testTotal; l_tot /= testTotal; l3_tot /= testTotal; l1_tot /= testTotal; r_tot /= testTotal; r3_tot /= testTotal; r1_tot /= testTotal; // with filter l_filter_rank /= testTotal; r_filter_rank /= testTotal; l_filter_reci_rank /= testTotal; r_filter_reci_rank /= testTotal; l_filter_tot /= testTotal; l3_filter_tot /= testTotal; l1_filter_tot /= testTotal; r_filter_tot /= testTotal; r3_filter_tot /= testTotal; r1_filter_tot /= testTotal; printf("no type constraint results:\n"); printf("metric:\t\t\t MRR \t\t MR \t\t hit@10 \t hit@3 \t hit@1 \n"); printf("l(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", l_reci_rank, l_rank, l_tot, l3_tot, l1_tot); printf("r(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", r_reci_rank, r_rank, r_tot, r3_tot, r1_tot); printf("averaged(raw):\t\t %f \t %f \t %f \t %f 
\t %f \n", (l_reci_rank+r_reci_rank)/2, (l_rank+r_rank)/2, (l_tot+r_tot)/2, (l3_tot+r3_tot)/2, (l1_tot+r1_tot)/2); printf("\n"); printf("l(filter):\t\t %f \t %f \t %f \t %f \t %f \n", l_filter_reci_rank, l_filter_rank, l_filter_tot, l3_filter_tot, l1_filter_tot); printf("r(filter):\t\t %f \t %f \t %f \t %f \t %f \n", r_filter_reci_rank, r_filter_rank, r_filter_tot, r3_filter_tot, r1_filter_tot); printf("averaged(filter):\t %f \t %f \t %f \t %f \t %f \n", (l_filter_reci_rank+r_filter_reci_rank)/2, (l_filter_rank+r_filter_rank)/2, (l_filter_tot+r_filter_tot)/2, (l3_filter_tot+r3_filter_tot)/2, (l1_filter_tot+r1_filter_tot)/2); mrr = (l_filter_reci_rank+r_filter_reci_rank) / 2; mr = (l_filter_rank+r_filter_rank) / 2; hit10 = (l_filter_tot+r_filter_tot) / 2; hit3 = (l3_filter_tot+r3_filter_tot) / 2; hit1 = (l1_filter_tot+r1_filter_tot) / 2; if (type_constrain) { //type constrain l_rank_constrain /= testTotal; r_rank_constrain /= testTotal; l_reci_rank_constrain /= testTotal; r_reci_rank_constrain /= testTotal; l_tot_constrain /= testTotal; l3_tot_constrain /= testTotal; l1_tot_constrain /= testTotal; r_tot_constrain /= testTotal; r3_tot_constrain /= testTotal; r1_tot_constrain /= testTotal; // with filter l_filter_rank_constrain /= testTotal; r_filter_rank_constrain /= testTotal; l_filter_reci_rank_constrain /= testTotal; r_filter_reci_rank_constrain /= testTotal; l_filter_tot_constrain /= testTotal; l3_filter_tot_constrain /= testTotal; l1_filter_tot_constrain /= testTotal; r_filter_tot_constrain /= testTotal; r3_filter_tot_constrain /= testTotal; r1_filter_tot_constrain /= testTotal; printf("type constraint results:\n"); printf("metric:\t\t\t MRR \t\t MR \t\t hit@10 \t hit@3 \t hit@1 \n"); printf("l(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", l_reci_rank_constrain, l_rank_constrain, l_tot_constrain, l3_tot_constrain, l1_tot_constrain); printf("r(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", r_reci_rank_constrain, r_rank_constrain, r_tot_constrain, r3_tot_constrain, 
r1_tot_constrain); printf("averaged(raw):\t\t %f \t %f \t %f \t %f \t %f \n", (l_reci_rank_constrain+r_reci_rank_constrain)/2, (l_rank_constrain+r_rank_constrain)/2, (l_tot_constrain+r_tot_constrain)/2, (l3_tot_constrain+r3_tot_constrain)/2, (l1_tot_constrain+r1_tot_constrain)/2); printf("\n"); printf("l(filter):\t\t %f \t %f \t %f \t %f \t %f \n", l_filter_reci_rank_constrain, l_filter_rank_constrain, l_filter_tot_constrain, l3_filter_tot_constrain, l1_filter_tot_constrain); printf("r(filter):\t\t %f \t %f \t %f \t %f \t %f \n", r_filter_reci_rank_constrain, r_filter_rank_constrain, r_filter_tot_constrain, r3_filter_tot_constrain, r1_filter_tot_constrain); printf("averaged(filter):\t %f \t %f \t %f \t %f \t %f \n", (l_filter_reci_rank_constrain+r_filter_reci_rank_constrain)/2, (l_filter_rank_constrain+r_filter_rank_constrain)/2, (l_filter_tot_constrain+r_filter_tot_constrain)/2, (l3_filter_tot_constrain+r3_filter_tot_constrain)/2, (l1_filter_tot_constrain+r1_filter_tot_constrain)/2); mrrTC = (l_filter_reci_rank_constrain+r_filter_reci_rank_constrain)/2; mrTC = (l_filter_rank_constrain+r_filter_rank_constrain) / 2; hit10TC = (l_filter_tot_constrain+r_filter_tot_constrain) / 2; hit3TC = (l3_filter_tot_constrain+r3_filter_tot_constrain) / 2; hit1TC = (l1_filter_tot_constrain+r1_filter_tot_constrain) / 2; } } extern "C" void test_relation_prediction() { rel_rank /= testTotal; rel_reci_rank /= testTotal; rel_tot /= testTotal; rel3_tot /= testTotal; rel1_tot /= testTotal; // with filter rel_filter_rank /= testTotal; rel_filter_reci_rank /= testTotal; rel_filter_tot /= testTotal; rel3_filter_tot /= testTotal; rel1_filter_tot /= testTotal; printf("no type constraint results:\n"); printf("metric:\t\t\t MRR \t\t MR \t\t hit@10 \t hit@3 \t hit@1 \n"); printf("averaged(raw):\t\t %f \t %f \t %f \t %f \t %f \n", rel_reci_rank, rel_rank, rel_tot, rel3_tot, rel1_tot); printf("\n"); printf("averaged(filter):\t %f \t %f \t %f \t %f \t %f \n", rel_filter_reci_rank, rel_filter_rank, 
rel_filter_tot, rel3_filter_tot, rel1_filter_tot); } extern "C" REAL getTestLinkHit10(bool type_constrain = false) { if (type_constrain) return hit10TC; printf("%f\n", hit10); return hit10; } extern "C" REAL getTestLinkHit3(bool type_constrain = false) { if (type_constrain) return hit3TC; return hit3; } extern "C" REAL getTestLinkHit1(bool type_constrain = false) { if (type_constrain) return hit1TC; return hit1; } extern "C" REAL getTestLinkMR(bool type_constrain = false) { if (type_constrain) return mrTC; return mr; } extern "C" REAL getTestLinkMRR(bool type_constrain = false) { if (type_constrain) return mrrTC; return mrr; } /*===================================================================================== triple classification ======================================================================================*/ Triple *negTestList = NULL; extern "C" void getNegTest() { if (negTestList == NULL) negTestList = (Triple *)calloc(testTotal, sizeof(Triple)); for (INT i = 0; i < testTotal; i++) { negTestList[i] = testList[i]; if (randd(0) % 1000 < 500) negTestList[i].t = corrupt_head(0, testList[i].h, testList[i].r); else negTestList[i].h = corrupt_tail(0, testList[i].t, testList[i].r); } } extern "C" void getTestBatch(INT *ph, INT *pt, INT *pr, INT *nh, INT *nt, INT *nr) { getNegTest(); for (INT i = 0; i < testTotal; i++) { ph[i] = testList[i].h; pt[i] = testList[i].t; pr[i] = testList[i].r; nh[i] = negTestList[i].h; nt[i] = negTestList[i].t; nr[i] = negTestList[i].r; } } #endif
#ifndef TRIPLE_H #define TRIPLE_H #include "Setting.h" struct Triple { INT h, r, t; static bool cmp_head(const Triple &a, const Triple &b) { return (a.h < b.h)||(a.h == b.h && a.r < b.r)||(a.h == b.h && a.r == b.r && a.t < b.t); } static bool cmp_tail(const Triple &a, const Triple &b) { return (a.t < b.t)||(a.t == b.t && a.r < b.r)||(a.t == b.t && a.r == b.r && a.h < b.h); } static bool cmp_rel(const Triple &a, const Triple &b) { return (a.h < b.h)||(a.h == b.h && a.t < b.t)||(a.h == b.h && a.t == b.t && a.r < b.r); } static bool cmp_rel2(const Triple &a, const Triple &b) { return (a.r < b.r)||(a.r == b.r && a.h < b.h)||(a.r == b.r && a.h == b.h && a.t < b.t); } }; #endif
TransE模型的定义位于train_transe_FB15K237.py中:
# Define the model: a TransE instance whose entity/relation counts are read
# from the training data loader.
transe = TransE(
ent_tot = train_dataloader.get_ent_tot(), # total number of entities
rel_tot = train_dataloader.get_rel_tot(), # total number of relations
dim = 200, # embedding dimension
# NOTE(review): presumably the order of the norm in the distance score (1 = L1); confirm in TransE
p_norm = 1,
# NOTE(review): presumably enables normalization of embeddings before scoring; confirm in TransE
norm_flag = True)
损失函数的定义位于train_transe_FB15K237.py中:
# Define the loss function: wrap the TransE model in NegativeSampling and
# train it with a margin loss (margin = 5.0).
model = NegativeSampling(
model = transe,
loss = MarginLoss(margin = 5.0),
# Batch size as computed by the data loader (tripleTotal // nbatches when not given).
# NOTE(review): presumably used to separate positive from negative samples in each batch; confirm
batch_size = train_dataloader.get_batch_size()
)
模型训练的代码位于train_transe_FB15K237.py中:
# Train the model.
#   model:       the model to train (here TransE wrapped in NegativeSampling)
#   data_loader: the training data loader
#   train_times: number of training rounds
#   alpha:       learning rate (step size). In gradient descent it sets how far
#                each iteration moves along the descent direction. If it is too
#                large, the iterates bounce around the optimum and never
#                converge; if it is too small, convergence becomes very slow.
#   use_gpu:     whether to train on the GPU; enabled only when CUDA is
#                available, so the script also runs on CPU-only machines.
import os

import torch

trainer = Trainer(
    model=model,
    data_loader=train_dataloader,
    train_times=1000,
    alpha=1.0,
    use_gpu=torch.cuda.is_available(),
)
trainer.run()

# Make sure the checkpoint directory exists. save_checkpoint presumably writes
# via torch.save, which does not create missing directories.
os.makedirs('./checkpoint', exist_ok=True)
transe.save_checkpoint('./checkpoint/transe.ckpt')
参考文章
[1]: https://www.codetd.com/article/7778596
[2]: http://139.129.163.161/index/toolkits#openke
[3]: https://github.com/thunlp/OpenKE
[4]: http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。