赞
踩
本文以项目readme.md训练逻辑的顺序解读
更多bert模型参考github地址
本文用的是BERT-Base, Cased(12-layer, 768-hidden, 12-heads , 110M parameters)下载地址。其中Cased表示保留真实的大小写和重音标记符,uncased表示文本在单词标记之前就已经变为小写,也去掉了任何重音标记例如,John Smith变成john smith。通常,Uncased模型会更好,除非大小写信息对于我们的任务很重要,如命名实体识别或词性标记。
以NYT数据集为例:NYT数据集是关于远程监督关系抽取任务的广泛使用的数据集。该数据集是通过将freebase中的关系与纽约时报(NYT)语料库对齐而生成的。纽约时报New York Times数据集包含150篇来自纽约时报的商业文章。抓取了从2009年11月到2010年1月纽约时报网站上的所有文章。在句子拆分和标记化之后,使用斯坦福NER标记器来标识PER和ORG从每个句子中的命名实体。对于包含多个标记的命名实体,我们将它们连接成单个标记。然后,我们将同一句子中出现的每一对(PER,ORG)实体作为单个候选关系实例,PER实体被视为ARG-1,ORG实体被视为ARG-2。
generate.py
# 将raw_NYT\train.json中的数字形式生成训练集的文本形式 def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file): with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, \ open(epo_file,'w') as f4, open(seo_file, 'w') as f5: seo_file, 'w') as f5: cnt_normal = 0 cnt_epo = 0 cnt_seo = 0 lines = f1.readlines() # readlines()方法用于读取所有行(直到结束符EOF)并返回列表 for line in lines: line = json.loads(line) print(len(line)) lengths, sents, spos = line[0], line[1], line[2] print(len(spos)) print(len(sents)) for i in range(len(sents)): new_line = dict() # print(sents[i]) # print(spos[i]) tokens = [word_dict[i] for i in sents[i]] # tokens为sents对应数字形式的字符串数组 sent = ' '.join(tokens) # 以空格形式连接字符串数组生成一个新的字符串 new_line['sentText'] = sent # new_line为包含三元组的字典 triples = np.reshape(spos[i], (-1, 3)) # 将spo[i]关系三元组的维度变为3列 relationMentions = [] for triple in triples: rel = dict() rel['em1Text'] = tokens[triple[0]] rel['em2Text'] = tokens[triple[1]] rel['label'] = rel_dict[triple[2]] relationMentions.append(rel) new_line['relationMentions'] = relationMentions f2.write(json.dumps(new_line) + '\n') if is_normal_triple(spos[i]): f3.write(json.dumps(new_line) + '\n') if is_multi_label(spos[i]): f4.write(json.dumps(new_line) + '\n') if is_over_lapping(spos[i]): f5.write(json.dumps(new_line) + '\n')
build_data.py
# 读取数据集文件,将文本、三元组分类存储 with open('train.json') as f: for l in tqdm(f): # tqdm是可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator) a = json.loads(l) if not a['relationMentions']: # 若某个句子a中关系'relationMentions'为空,跳过之 continue # 提取出每个句子及其三元组 line = { 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), #去除'sentText'中的\r(回车)、\n(换行)、两头的'\' 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] } if not line['triple_list']: continue # 将提取出来的句子及其三元组信息加入到训练集数据train_data中,将三元组中的关系加入到集合rel_set中(无序不重复元素序列) train_data.append(line) for rm in a['relationMentions']: if rm['label'] != 'None': rel_set.add(rm['label'])
run.py中的默认参数:
{
"bert_model": "cased_L-12_H-768_A-12",
"max_len": 100,
"learning_rate": 1e-5,
"batch_size": 6,
"epoch_num": 100,
}
根据自己的设置修改
确定运行方式,使用的数据集
python run.py ---train=True --dataset=NYT
在测试集上评估
python run.py --dataset=NYT
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。