Triple extraction task, based on the "half-pointer, half-annotation" structure.
Write-up: https://kexue.fm/archives/7161
Dataset: http://ai.baidu.com/broad/download?dataset=sked
Best F1 = 0.82198
Switching to RoBERTa Large reaches F1 = 0.829+
Note: since EMA is used, it only takes effect after enough training steps (5000+). If your dataset is small, make sure to run enough epochs, or remove EMA.
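The EMA weights are maintained by bert4keras's optimizer wrapper, imported below as extend_with_exponential_moving_average. The following is a minimal sketch of how it is typically wired up; the learning rate and EMA momentum are illustrative values, not taken from this page. apply_ema_weights/reset_old_weights swap the averaged weights in and out around evaluation.

from bert4keras.optimizers import Adam, extend_with_exponential_moving_average

# Sketch (assumption): wrap Adam so a shadow EMA copy of the weights is kept.
AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
optimizer = AdamEMA(learning_rate=1e-5, ema_momentum=0.999)  # illustrative values

# train_model.compile(loss=..., optimizer=optimizer)
# Around evaluation, swap the averaged weights in and back out:
# optimizer.apply_ema_weights()
# ... run evaluation ...
# optimizer.reset_old_weights()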
import json
import numpy as np
from bert4keras.backend import keras, K, batch_gather
from bert4keras.layers import Loss
from bert4keras.layers import LayerNormalization
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, to_array
from keras.layers import Input, Dense, Lambda, Reshape
from keras.models import Model
from tqdm import tqdm

maxlen = 128
batch_size = 64

config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
def load_data(filename):
    """Load the dataset.
    Format of a single sample: {'text': text, 'spo_list': [(s, p, o)]}
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            l = json.loads(l)
            D.append({
                'text': l['text'],
                'spo_list': [(spo['subject'], spo['predicate'], spo['object'])
                             for spo in l['spo_list']]
            })
    return D


# Load the datasets
train_data = load_data('/root/kg/datasets/train_data.json')
valid_data = load_data('/root/kg/datasets/dev_data.json')

# Build the predicate <-> id mappings from the schema file
predicate2id, id2predicate = {}, {}
with open('/root/kg/datasets/all_50_schemas') as f:
    for l in f:
        l = json.loads(l)
        if l['predicate'] not in predicate2id:
            id2predicate[len(predicate2id)] = l['predicate']
            predicate2id[l['predicate']] = len(predicate2id)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
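For reference, each line of the dataset files is a standalone JSON object; load_data keeps only the text and flattens every spo_list entry into an (s, p, o) tuple. The line below is a hand-written, illustrative example (not taken from the dataset) showing what the resulting record looks like:

# Illustrative input line (hand-written example, not from the dataset):
line = ('{"text": "《三体》的作者是刘慈欣", '
        '"spo_list": [{"subject": "三体", "predicate": "作者", "object": "刘慈欣"}]}')
# After load_data, the corresponding record becomes:
# {'text': '《三体》的作者是刘慈欣', 'spo_list': [('三体', '作者', '刘慈欣')]}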
def search(pattern, sequence):
    """Find the sublist `pattern` inside `sequence`.
    Returns the index of the first match, or -1 if not found.
    """
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            return i
    return -1


class data_generator(DataGenerator):
    """Data generator."""
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(d['text'], maxlen=maxlen)
            # Group the triples as {s: [(o, p)]}
            spoes = {}
            for s, p, o in d['spo_list']:
                s = tokenizer.encode(s)[0][1:-1]
                p = predicate2id[p]
                o = tokenizer.encode(o)[0][1:-1]
                s_idx = search(s, token_ids)
                o_idx = search(o, token_ids)
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
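            # --- The snippet is cut off at this point; what follows is a sketch
            # (an assumption, not code recovered from this page) of how the grouped
            # spoes dict is typically turned into "half-pointer, half-annotation"
            # labels as described in the write-up linked above: a (seq_len, 2)
            # start/end matrix for subjects, plus a (seq_len, num_predicates, 2)
            # matrix for the objects of one randomly sampled subject. Variable
            # names and the sampling strategy below are illustrative.
            if spoes:
                # Subject labels: column 0 marks start tokens, column 1 end tokens.
                subject_labels = np.zeros((len(token_ids), 2))
                for s_start, s_end in spoes:
                    subject_labels[s_start, 0] = 1
                    subject_labels[s_end, 1] = 1
                # Randomly pick one annotated subject for this sample.
                subject_ids = list(spoes.keys())[np.random.choice(len(spoes))]
                # Object labels: start/end positions per predicate for that subject.
                object_labels = np.zeros((len(token_ids), len(predicate2id), 2))
                for o_start, o_end, p in spoes[subject_ids]:
                    object_labels[o_start, p, 0] = 1
                    object_labels[o_end, p, 1] = 1
                # The batch lists would then be appended to, padded with
                # sequence_padding, and yielded once batch_size samples accumulate.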