当前位置:   article > 正文

哈工大关系抽取模型CasRel代码解读_casrel代码详解

casrel代码详解

本文以项目readme.md训练逻辑的顺序解读

1.下载BERT预训练模型

      更多bert模型参考github地址
      本文用的是BERT-Base, Cased(12-layer, 768-hidden, 12-heads , 110M parameters)下载地址。其中Cased表示保留真实的大小写和重音标记符,uncased表示文本在单词标记之前就已经变为小写,也去掉了任何重音标记例如,John Smith变成john smith。通常,Uncased模型会更好,除非大小写信息对于我们的任务很重要,如命名实体识别或词性标记。

2.以三元组的形式构造数据集

      以NYT数据集为例:NYT数据集是关于远程监督关系抽取任务的广泛使用的数据集。该数据集是通过将freebase中的关系与纽约时报(NYT)语料库对齐而生成的。纽约时报New York Times数据集包含150篇来自纽约时报的商业文章。抓取了从2009年11月到2010年1月纽约时报网站上的所有文章。在句子拆分和标记化之后,使用斯坦福NER标记器来标识PER和ORG从每个句子中的命名实体。对于包含多个标记的命名实体,我们将它们连接成单个标记。然后,我们将同一句子中出现的每一对(PER,ORG)实体作为单个候选关系实例,PER实体被视为ARG-1,ORG实体被视为ARG-2。

2.1运行过程

  1. NYT数据集下载地址
  2. 运行CasRel/data/NYT/raw_NYT/generate.py将数字编码形式的nyt数据集转换为字符形式的数据集,并根据三元组将数据集分类为normal,epo,spo几种类型。运行结果为CasRel/data/NYT/new_train.json,new_train_epo.json,new_train_normal.json,new_train_seo.json。将产生的新文件移至NYT/目录下,将new_train.json改名为train.json,test与valid同理。
  3. 运行CasRel/data/NYT/build_data.py处理数据得到三元组,按train,dev,test分类。
  4. 将test.json移至test_split_by_num目录中,运行split_by_num.py将test集按每个句子含有的三元组数量分类;将test_epo.json,test_normal.json,test_seo.json移至test_split_by_type目录中,将test集按类型分类得到三元组文件。

2.2重要的代码

generate.py

# 将raw_NYT\train.json中的数字形式生成训练集的文本形式
def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file):
   with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, \
            open(epo_file,'w') as f4, open(seo_file, 'w') as f5:
        seo_file, 'w') as f5:
        cnt_normal = 0
        cnt_epo = 0
        cnt_seo = 0
        lines = f1.readlines()  # readlines()方法用于读取所有行(直到结束符EOF)并返回列表
        for line in lines:
            line = json.loads(line)
            print(len(line))
            lengths, sents, spos = line[0], line[1], line[2] 
            print(len(spos))
            print(len(sents))
            for i in range(len(sents)):
                new_line = dict()
                # print(sents[i])
                # print(spos[i])
                tokens = [word_dict[i] for i in sents[i]]  # tokens为sents对应数字形式的字符串数组
                sent = ' '.join(tokens)  # 以空格形式连接字符串数组生成一个新的字符串
                new_line['sentText'] = sent # new_line为包含三元组的字典
                triples = np.reshape(spos[i], (-1, 3))  # 将spo[i]关系三元组的维度变为3列
                relationMentions = []
                for triple in triples:
                    rel = dict()
                    rel['em1Text'] = tokens[triple[0]]
                    rel['em2Text'] = tokens[triple[1]]
                    rel['label'] = rel_dict[triple[2]]
                    relationMentions.append(rel)
                new_line['relationMentions'] = relationMentions
                f2.write(json.dumps(new_line) + '\n')
                if is_normal_triple(spos[i]):
                    f3.write(json.dumps(new_line) + '\n')
                if is_multi_label(spos[i]):
                    f4.write(json.dumps(new_line) + '\n')
                if is_over_lapping(spos[i]):
                    f5.write(json.dumps(new_line) + '\n')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

build_data.py

# 读取数据集文件,将文本、三元组分类存储
with open('train.json') as f:
    for l in tqdm(f):  # tqdm是可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)
        a = json.loads(l)
        if not a['relationMentions']:  # 若某个句子a中关系'relationMentions'为空,跳过之
            continue
        # 提取出每个句子及其三元组
        line = {
            'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), #去除'sentText'中的\r(回车)、\n(换行)、两头的'\'
            'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if
                            i['label'] != 'None']
        }
        if not line['triple_list']:
            continue
        # 将提取出来的句子及其三元组信息加入到训练集数据train_data中,将三元组中的关系加入到集合rel_set中(无序不重复元素序列)
        train_data.append(line)
        for rm in a['relationMentions']:
            if rm['label'] != 'None':
                rel_set.add(rm['label'])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

3.指定实验设置

run.py中的默认参数:

{
    "bert_model": "cased_L-12_H-768_A-12",
    "max_len": 100,
    "learning_rate": 1e-5,
    "batch_size": 6,
    "epoch_num": 100,
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

根据自己的设置修改

4.训练模型并评估

确定运行方式,使用的数据集

python run.py ---train=True --dataset=NYT
  • 1

在测试集上评估

python run.py --dataset=NYT
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/581834
推荐阅读
相关标签
  

闽ICP备14008679号