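I used to extract named entities with template rules. That approach is very fast, but the rules can conflict with one another. I also tried Baidu LAC with part-of-speech template rules; a small number of rules covers most cases, but rule conflicts still occur.

This post tries BERT + BiLSTM + CRF for named entity extraction instead. BERT learns semantic features from the corpus, the BiLSTM captures longer-range context between tokens, and the CRF corrects tag-order errors in the BiLSTM's predictions (for example, an I- tag that does not follow a matching B- tag). BERT is highly accurate, but its drawback is just as clear: inference is slow. Inference performance can be improved at deployment time, for example by serving the model with ONNX Runtime.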
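To make the "CRF corrects tag-order errors" point concrete, here is a minimal sketch in plain Python. It is not the CRF layer that Kashgari trains: a real CRF learns transition scores from data, while this sketch simply forbids invalid BIO transitions and runs Viterbi decoding over them. The tag set and the emission scores are made-up values for illustration.

```python
import math

TAGS = ["O", "B-LOC", "I-LOC"]  # hypothetical tag set for illustration


def allowed(prev, cur):
    """BIO rule: I-X may only follow B-X or I-X."""
    if cur.startswith("I-"):
        return prev in ("B-" + cur[2:], "I-" + cur[2:])
    return True


def greedy(emissions):
    """Pick the best tag at each position independently (no ordering rules)."""
    return [max(em, key=em.get) for em in emissions]


def viterbi(emissions, tags=TAGS):
    """Return the highest-scoring tag path that respects the BIO rule."""
    neg = -math.inf
    # A sequence may not start with an I- tag.
    score = {t: (neg if t.startswith("I-") else emissions[0][t]) for t in tags}
    backpointers = []
    for em in emissions[1:]:
        new_score, pointers = {}, {}
        for cur in tags:
            # Only predecessors that form a valid transition are considered.
            candidates = {p: score[p] for p in tags if allowed(p, cur)}
            best_prev = max(candidates, key=candidates.get)
            new_score[cur] = candidates[best_prev] + em[cur]
            pointers[cur] = best_prev
        score = new_score
        backpointers.append(pointers)
    # Walk back from the best final tag.
    path = [max(score, key=score.get)]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    return path[::-1]


# Made-up per-character scores for the three characters of "在上海".
EMISSIONS = [
    {"O": 2.0, "B-LOC": 0.1, "I-LOC": 0.0},
    {"O": 0.1, "B-LOC": 1.0, "I-LOC": 1.2},
    {"O": 0.1, "B-LOC": 0.2, "I-LOC": 1.5},
]

print(greedy(EMISSIONS))   # per-position choice, may break the BIO order
print(viterbi(EMISSIONS))  # sequence-level choice under the BIO rule
```

With these scores the per-position choice starts the entity with I-LOC directly after O, which is not a valid BIO sequence; the sequence-level decoder chooses B-LOC for that character instead.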
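The main steps are:
1) Prepare an annotated corpus (I prepared 224 annotations myself) and convert it to the same format as the People's Daily corpus (the conversion code comes from the internet); domain-specific entity types can be defined freely;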
    # Generate the training corpus: one character per line,
    # same format as the People's Daily corpus
    import re


    # str2ner_train_data: convert an annotated string into NER training data
    # s: annotated string, e.g. '我来到[@1999年#YEAR*]的[@上海#LOC*]的[@东华大学#SCHOOL*]'
    # save_path: path of the generated training file
    def str2ner_train_data(s, save_path):
        ner_data = []
        # Start offset of every '[@' and end offset of every '*]'
        begin = [m.start() for m in re.finditer(r'\[@', s)]
        end = [m.end() for m in re.finditer(r'\*\]', s)]
        assert len(begin) == len(end), \
            'unbalanced markers: %d "[@" vs %d "*]"' % (len(begin), len(end))
        begin_set = set(begin)
        i = 0
        j = 0
        while i < len(s):
            if i not in begin_set:
                # Character outside any annotation
                ner_data.append([s[i], 'O'])
                i += 1
            else:
                # Strip the leading '[@' and the trailing '*]'
                ann = s[i + 2:end[j] - 2]
                # Split on the last '#' only, so a '#' inside the entity text is tolerated
                entity, ner = ann.rsplit('#', 1)
                # BIO scheme: first character gets B-, the rest get I-.
                # (For BIOES, a single character would get S- and the last one E-.)
                ner_data.append([entity[0], 'B-' + ner])
                for ch in entity[1:]:
                    ner_data.append([ch, 'I-' + ner])
                i = end[j]
                j += 1

        # Append mode, as in the original: delete the old file before re-running,
        # otherwise the corpus is written twice.
        with open(save_path, 'a', encoding='utf-8') as f:
            for char, tag in ner_data:
                f.write(char + ' ' + tag + '\n')
                # A blank line after sentence-final punctuation separates sentences
                # (full-width '?' and '!' are checked in addition to the ASCII forms)
                if char in '。?!?!':
                    f.write('\n')


    # txt2ner_train_data: convert an annotated multi-line text file into NER training data
    # file_path: annotated text, e.g. '我来到[@1999年#YEAR*]的[@上海#LOC*]的[@东华大学#SCHOOL*]'
    # save_path: path of the generated training file
    def txt2ner_train_data(file_path, save_path):
        with open(file_path, 'r', encoding='utf-8') as fr:
            lines = fr.readlines()
        # Join all lines into one string. Spaces are dropped because the output
        # format uses a space to separate the character from its tag.
        s = ''.join(line.replace('\n', '').replace(' ', '') for line in lines)
        str2ner_train_data(s, save_path)


    if __name__ == '__main__':
        # Generated training corpus: one character per line, People's Daily format
        train_path = './train.txt'
        # Corpus annotated for your domain; NER labels can be customized and are
        # not limited to PER (person), LOC (location), ORG (organization)
        corpus_path = './middle_corpus.txt'
        txt2ner_train_data(corpus_path, train_path)
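As a worked example of what the `[@entity#LABEL*]` annotation turns into, the sketch below converts one annotated string to (character, tag) pairs in memory, without writing a file. It is a compact regex-based variant written for illustration, not the script used in this post; `annotate_to_bio` and `ANNOTATION` are names introduced here.

```python
import re

# '[@' + entity text + '#' + label + '*]'; the label may not contain '#', '*' or ']'
ANNOTATION = re.compile(r"\[@(.+?)#([^#*\]]+)\*\]")


def annotate_to_bio(s):
    """Convert an annotated string into a list of (character, BIO tag) pairs."""
    pairs = []
    pos = 0
    for m in ANNOTATION.finditer(s):
        # Plain text before this annotation is tagged 'O'.
        pairs.extend((ch, "O") for ch in s[pos:m.start()])
        entity, label = m.group(1), m.group(2)
        pairs.append((entity[0], "B-" + label))
        pairs.extend((ch, "I-" + label) for ch in entity[1:])
        pos = m.end()
    # Plain text after the last annotation.
    pairs.extend((ch, "O") for ch in s[pos:])
    return pairs


sample = "我来到[@上海#LOC*]的[@东华大学#SCHOOL*]"
for ch, tag in annotate_to_bio(sample):
    print(ch, tag)
```

Each printed line has the same "character, space, tag" shape as a line of the generated training file.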
    # Load your own corpus (test.txt must be in the same format as train.txt)
    train_path = './train.txt'
    test_path = './test.txt'


    def get_sequenct_tagging_data(file_path):
        data_x, data_y = [], []

        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()

        x, y = [], []
        for line in lines:
            rows = line.split(' ')
            if len(rows) == 1:
                # A blank line marks a sentence boundary; skip empty sentences
                if x:
                    data_x.append(x)
                    data_y.append(y)
                x, y = [], []
            else:
                x.append(rows[0])
                y.append(rows[1])
        # Keep the last sentence when the file does not end with a blank line
        if x:
            data_x.append(x)
            data_y.append(y)
        return data_x, data_y


    train_x, train_y = get_sequenct_tagging_data(train_path)
    validate_x, validate_y = get_sequenct_tagging_data(test_path)
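With a corpus as small as the one used here, it can help to check the loaded data before training. The following is a small sketch, not part of the original workflow: `summarize` is a hypothetical helper that verifies every sentence has one tag per character and counts how many entities of each type are present. The sample data at the bottom is made up.

```python
from collections import Counter


def summarize(data_x, data_y):
    """Check character/tag alignment and count tags and entities."""
    if len(data_x) != len(data_y):
        raise ValueError("different number of sentences and tag sequences")
    for idx, (x, y) in enumerate(zip(data_x, data_y)):
        if len(x) != len(y):
            raise ValueError("sentence %d: %d chars but %d tags" % (idx, len(x), len(y)))
    # Every tag occurrence, e.g. how dominant 'O' is.
    tag_counts = Counter(tag for y in data_y for tag in y)
    # One entity per B- tag, grouped by entity type.
    entity_counts = Counter(tag[2:] for y in data_y for tag in y if tag.startswith("B-"))
    return tag_counts, entity_counts


# Made-up data in the shape returned by the loader: lists of characters and tags.
sample_x = [["我", "在", "上", "海"], ["东", "华", "大", "学"]]
sample_y = [["O", "O", "B-LOC", "I-LOC"], ["B-SCHOOL", "I-SCHOOL", "I-SCHOOL", "I-SCHOOL"]]

tag_counts, entity_counts = summarize(sample_x, sample_y)
print(tag_counts)
print(entity_counts)
```

An entity type that appears only a handful of times in `entity_counts` is a hint that more annotations for that type may be needed.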
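2) Use kashgari 2.0.1 to train the model quickly, with BERT as the feature extractor and the Chinese pre-trained model chinese_L-12_H-768_A-12 (which must be downloaded to the local machine beforehand);
3) Save and load the model;
4) Run inference with the model. The results are quite good, better than those of Baidu LAC.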
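Steps 2) to 4) might look roughly like the sketch below. The Kashgari calls (`BertEmbedding`, `BiLSTM_CRF_Model`, `fit`, `save`, `load_model`, `predict`) are written from memory of the Kashgari 2.x README and should be checked against the installed version; `sequence_length`, `epochs` and `batch_size` are assumed values, not the settings used in this post. `extract_entities` is a hypothetical helper, in plain Python, that turns predicted BIO tags back into entity strings.

```python
def train_save_reload(train_x, train_y, valid_x, valid_y, bert_dir, model_dir):
    """Train a BERT + BiLSTM + CRF tagger, save it, and return the reloaded model."""
    # Imported lazily so extract_entities below works without Kashgari installed.
    from kashgari.embeddings import BertEmbedding
    from kashgari.tasks.labeling import BiLSTM_CRF_Model

    # bert_dir: local folder of the pre-trained model, e.g. chinese_L-12_H-768_A-12
    embedding = BertEmbedding(bert_dir)
    model = BiLSTM_CRF_Model(embedding, sequence_length=100)  # assumed length
    model.fit(train_x, train_y, valid_x, valid_y, epochs=10, batch_size=32)  # assumed values
    model.save(model_dir)
    return BiLSTM_CRF_Model.load_model(model_dir)


def extract_entities(chars, tags):
    """Group BIO tags into (entity text, label) pairs.

    An I- tag that does not continue the current entity is dropped.
    """
    entities = []
    current, label = [], None
    for ch, tag in zip(chars, tags):
        if tag.startswith("B-"):
            if current:
                entities.append(("".join(current), label))
            current, label = [ch], tag[2:]
        elif tag.startswith("I-") and label == tag[2:]:
            current.append(ch)
        else:
            if current:
                entities.append(("".join(current), label))
            current, label = [], None
    if current:
        entities.append(("".join(current), label))
    return entities


# Hand-written tags standing in for a model prediction.
chars = list("我来到上海的东华大学")
tags = ["O", "O", "O", "B-LOC", "I-LOC", "O",
        "B-SCHOOL", "I-SCHOOL", "I-SCHOOL", "I-SCHOOL"]
print(extract_entities(chars, tags))
```

With a trained model, the tags would come from something like `model.predict([list(text)])[0]` rather than being written by hand.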