当前位置:   article > 正文

命名实体识别(NER)(二):BERT+CRF模型训练

bert+crf

摘要:上篇介绍了数据的标注过程,接下来就是模型的训练了,本文采用BERT+CRF模型进行训练。

 

采用kashgari模块(一个将Bert封装好的模块,用于快速搭建模型)的Bert模块快速搭建自己模型

在原来bert预训练模型基础上,导入自己的标注数据进一步训练,

bert中文预训练模型就自己下一下吧~

代码比较简单,就是导入数据,在预训练基础上进一步训练

  1. from kashgari.tasks.seq_labeling import BLSTMCRFModel
  2. from kashgari.embeddings import BERTEmbedding
  3. import kashgari
  4. from kashgari import utils
  5. import os
  6. #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  7. #os.environ["CUDA_VISIBLE_DEVICES"] = ""
  8. def get_sequence_tagging_data(file_path):
  9. data_x, data_y = [], []
  10. with open(file_path, 'r', encoding='utf-8') as f:
  11. lines = f.read().splitlines()
  12. # print(lines)
  13. x, y = [], []
  14. for line in lines:
  15. rows = line.split(' ')
  16. if len(rows) == 4:
  17. data_x.append(x)
  18. data_y.append(y)
  19. x = []
  20. y = []
  21. else:
  22. x.append(rows[0])
  23. y.append(rows[1])
  24. return data_x, data_y
  25. #train_x, train_y = get_sequence_tagging_data('training_data_bert_train.txt')
  26. # train_x, train_y = get_sequence_tagging_data('../new_note.txt')
  27. # print(f"train data count: {len(train_x)}")
  28. model training
  29. embedding = BERTEmbedding('/pvc/train/chinese_L-12_H-768_A-12',40)
  30. model = BLSTMCRFModel(embedding)
  31. model.fit(train_x,
  32. train_y,
  33. validation_split = 0.4,
  34. epochs=10,
  35. batch_size=32)
  36. print('model_save')
  37. model.save('../model_save/ner_model')

最后附上一段供测试的代码

  1. load_model = BLSTMCRFModel.load_model('../model_save/ner_model')
  2. print(load_model.predict("刘若英语怎么样"))

从训练数据的标注至NER模型训练流程就算简单走一遍了。

至此,NER本节就算完结了~逃

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/803356
推荐阅读
相关标签
  

闽ICP备14008679号