赞
踩
先上热菜致敬苏神:苏剑林. (2020, Jan 03). 《用bert4keras做三元组抽取 》[Blog post]. Retrieved from https://kexue.fm/archives/7161
建议大家先看苏神的原文,如果您能看懂思路和代码的话我的文章可能对你的帮助不大。
拜读这篇文章之后本人用TF + Transformers 复现了该baseline模型,并在其基础上进行了大量的尝试,直到心累也没有成功复现相同水平的结果,但也有所接近,因此用这篇文章复盘整个过程并分享一些收获和心得。
数据下载地址:https://ai.baidu.com/broad/download?dataset=sked
数据格式:
{
"text": "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部",
"spo_list":
[{
"predicate": "出生地", "object_type": "地点", "subject_type": "人物", "object": "圣地亚哥", "subject": "查尔斯·阿兰基斯"},
{
"predicate": "出生日期", "object_type": "Date", "subject_type": "人物", "object": "1989年4月17日", "subject": "查尔斯·阿兰基斯"}]}
简单来说给定一段文本,我们需要从中抽取出多组 S(subject) P(predicate) O(object_type)的关系。
例如:“查尔斯·阿兰基斯–出生日期–1989年4月17日”则是一组我们需要抽取出来的信息。而 P(需要预测的关系)已经给定范围,一共49类关系,具体见 all_50_schemas 。
这个模型思路的精彩之处:
该任务本来应该分成两个模块完成:1.抽取实体(包括S和O)2.判断实体之间的关系,理应至少需要两个模型协同完成,但苏神将实体之间的关系类别预测隐性的放在了O抽取的过程中,即让模型在预测O的时候直接预测O与S的关系P。
指针标注:对每个span的start和end进行标记,对于多片段抽取问题转化为N个2分类(N为序列长度),如果涉及多类别可以转化为层叠式指针标注(C个指针网络,C为类别总数)。事实上,指针标注已经成为统一实体、关系、事件抽取的一个“大杀器”。
由于一个文本中可能存在多对SPO关系组,甚至可能存在S之间有Overlap,O之间有Overlap的情况,因此模型的输出层使用的是半指针-半标注的sigmoid(类似多标签预测实体的始末位置,与阅读理解相似)这样可以让模型同时标注多对S和O。
使用Conditional Layer Normalization 我们需要在预测PO时告诉模型,我们的S是什么,以至于使得模型学习到PO的预测是依赖于S的,而不是看见“日期”就认为是出生年月。具体的内部实现流程也可以参考我的代码,会有介绍。(这各地方也卡了我很久才跑通)最后评估下来这个方法有利也有弊。
def load_data(path): text_list = [] spo_list = [] with open(path) as json_file: for i in json_file: text_list.append(eval(i)['text']) spo_list.append(eval(i)['spo_list']) return text_list,spo_list def load_ps(path): with open(path,'r') as f: data = pd.DataFrame([eval(i) for i in f])['predicate'] p2id = { } id2p = { } data = list(set(data)) for i in range(len(data)): p2id[data[i]] = i id2p[i] = data[i] return p2id,id2p
这里处理的思路和信息抽取(一)中处理的思路相似,有详细的代码注释:
信息抽取(一)机器阅读理解——样本数据处理与Baseline模型搭建训练(2020语言与智能技术竞赛)
这里主要介绍针对本次任务的几个细节和trick:
def proceed_data(text_list,spo_list,p2id,id2p,tokenizer,MAX_LEN): id_label = { } ct = len(text_list) MAX_LEN = MAX_LEN input_ids = np.zeros((ct,MAX_LEN),dtype='int32') attention_mask = np.zeros((ct,MAX_LEN),dtype='int32') start_tokens = np.zeros((ct,MAX_LEN),dtype='int32') end_tokens = np.zeros((ct,MAX_LEN),dtype='int32') send_s_po = np.zeros((ct,2),dtype='int32') object_start_tokens = np.zeros((ct,MAX_LEN,len(p2id)),dtype='int32') object_end_tokens = np.zeros((ct,MAX_LEN,len(p2id)),dtype='int32') invalid_index = [] for k in range(ct): context_k = text_list[k].lower().replace(' ','') enc_context = tokenizer.encode(context_k,max_length=MAX_LEN,truncation=True) if len(spo_list[k])==0: invalid_index.append(k) continue start = [] end = [] S_index = [] for j in range(len(spo_list[k])): answers_text_k = spo_list[k][j]['subject'].lower().replace(' ','') chars = np.zeros((len(context_k))) index = context_k.find(answers_text_k) chars[index:index+len(answers_text_k)]=1 offsets = [] idx=0 for t in enc_context[1:]: w = tokenizer.decode([t]) if '#'
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。