
AI from Zero: An NER Model Based on BiLSTM+CRF (Training, Loading, and Entity Extraction)


This article is for working folks who just want code that runs and have no interest in the underlying theory.

I'm a fresh hire myself and mostly care about getting things to work, so this article skips the math and sticks to the code, the runtime environment, and a few basic problems you may run into.

One reminder for everyone already on the job: your file paths and mine are really not the same. Don't hit run without changing them first.

On the test-environment dataset, this training rig took a bit over two hours for 50 epochs (I forgot to time it exactly).

Because of the sensitive nature of the clients my company serves, I can't use our own dataset in the examples; thanks for understanding.


My work environment: an x86, CPU-only Linux server with no access to the external network.

Base environment on the server: Anaconda3, torch + torchvision.

The packaged code is on GitHub: GitHub - jinzhangLi/BiLSTM-CRF

The data is in the GitHub repo; it comes from a competition dataset and includes Train.json, Test.json and dev.json (train, test, validation). A sketch of the assumed per-line format follows.
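
For reference, each line of these files is assumed to be one JSON object in the CLUENER-style format the loaders below rely on: a 'text' field with the raw sentence, and a 'label' field of the form {entity_type: {entity_string: [[start, end], ...]}} with inclusive end indices. A minimal sketch of one training line (the sentence and spans are illustrative):

import json

# Hypothetical example of a single line from Train.json; the values are illustrative.
line = '{"text": "浙商银行企业信贷部叶老桂博士", "label": {"company": {"浙商银行": [[0, 3]]}, "name": {"叶老桂": [[9, 11]]}}}'
sample = json.loads(line)
print(sample['text'])   # raw sentence; it is split character by character downstream
print(sample['label'])  # {entity_type: {entity_string: [[start, end], ...]}}, end index inclusive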


Warning: this article is organized into the following parts, plus usage reminders and notes:

Part 1: data loading and preprocessing (if your data is stored in a different format, you will need to write your own cleaning function)

Part 2: the model architecture and the Viterbi algorithm

Part 3: training control and training parameters (everyone's file paths are different, so check them before running)

Part 4: loading and using the trained model

Part 1: Even a skilled cook can't make a meal without rice

Data loading and preprocessing is the foundation of any supervised learning job; think of it as the cook washing the vegetables (Data_load.py):

import os
import pickle
import json
import torch


# Build the character vocabulary. Every character is mapped to an index, and each
# index is later turned into a word vector before going into the LSTM.
def get_vocab(data_path, vocab_path):
    # On the first run, scan the training set to build the vocabulary and save it;
    # on later runs, just load the saved file.
    if os.path.exists(vocab_path):
        with open(vocab_path, 'rb') as fp:
            vocab = pickle.load(fp)
    else:
        json_data = []
        # Load the dataset (one JSON object per line)
        with open(data_path, 'r', encoding='utf-8') as fp:
            for line in fp:
                json_data.append(json.loads(line))
        # Initialize the vocabulary with 'PAD' and 'UNK'
        # 'PAD': pads sequences of different lengths within a batch
        # 'UNK': replaces characters in the dev/test sets that are not in the vocabulary
        vocab = {'PAD': 0, 'UNK': 1}
        # Walk the dataset and give every new character an incremental index
        for data in json_data:
            for word in data['text']:
                if word not in vocab:
                    vocab[word] = len(vocab)
        # vocab: {'PAD': 0, 'UNK': 1, '浙': 2, '商': 3, '银': 4, '行': 5...}
        # Save as a pkl file
        with open(vocab_path, 'wb') as fp:
            pickle.dump(vocab, fp)
    # Inverted vocabulary: predictions come out as indices, this maps them back to characters
    # vocab_inv: {0: 'PAD', 1: 'UNK', 2: '浙', 3: '商', 4: '银', 5: '行'...}
    vocab_inv = {v: k for k, v in vocab.items()}
    return vocab, vocab_inv


def get_label_map(data_path, label_map_path):
    # On the first run, scan the training set to build the label dictionary and save it as JSON;
    # on later runs, just load the saved file.
    if os.path.exists(label_map_path):
        with open(label_map_path, 'r', encoding='utf-8') as fp:
            label_map = json.load(fp)
    else:
        # Read the JSON data
        json_data = []
        with open(data_path, 'r', encoding='utf-8') as fp:
            for line in fp:
                json_data.append(json.loads(line))
        # Collect the entity classes
        n_classes = []
        for data in json_data:
            for label in data['label'].keys():  # entity labels such as 'name', 'company'
                if label not in n_classes:      # record every new label
                    n_classes.append(label)
        n_classes.sort()
        # n_classes: ['address', 'book', 'company', 'game', 'government', 'movie', 'name', 'organization', 'position', 'scene']
        # Build label_map: every class gets a B- and an I- tag with its own ID, e.g. B-name, I-name
        label_map = {}
        for n_class in n_classes:
            label_map['B-' + n_class] = len(label_map)
            label_map['I-' + n_class] = len(label_map)
        label_map['O'] = len(label_map)
        # The BiLSTM+CRF needs start and stop tags to strengthen its transition constraints
        START_TAG = "<START>"
        STOP_TAG = "<STOP>"
        label_map[START_TAG] = len(label_map)
        label_map[STOP_TAG] = len(label_map)
        # Save label_map as a JSON file
        with open(label_map_path, 'w', encoding='utf-8') as fp:
            json.dump(label_map, fp, indent=4)
    # label_map_inv: {0: 'B-address', 1: 'I-address', 2: 'B-book', 3: 'I-book'...}
    label_map_inv = {v: k for k, v in label_map.items()}
    return label_map, label_map_inv


def data_process(path, is_train, text_list):
    # Read the JSON records into a list.
    # The file holds one JSON object per line, so it cannot be parsed with a single
    # json.loads; read it line by line instead.
    json_data = []
    with open(path, 'r', encoding='utf-8') as fp:
        for line in fp:
            json_data.append(json.loads(line))
    if is_train == 'train':
        data = []
        # Walk every record
        for i in range(len(json_data)):
            # Initialize every character's tag to 'O'
            label = ['O'] * len(json_data[i]['text'])
            # Walk the entity classes under 'label', e.g. 'name' and 'company'
            for n in json_data[i]['label']:
                # Walk the entity strings under each class, e.g. '叶老桂' under 'name'
                # (there may be several strings; the sample has only one)
                for key in json_data[i]['label'][n]:
                    # Walk the span lists, e.g. [[9, 11]]
                    # (a string may occur more than once, giving several spans)
                    for n_list in range(len(json_data[i]['label'][n][key])):
                        # Start and end index of the entity
                        start = json_data[i]['label'][n][key][n_list][0]
                        end = json_data[i]['label'][n][key][n_list][1]
                        # The start index gets 'B-' + n, e.g. 'B-name';
                        # the remaining indices get 'I-' + n
                        label[start] = 'B-' + n
                        label[start + 1: end + 1] = ['I-' + n] * (end - start)
            # Split the string character by character:
            # 'bag' becomes 'b', 'a', 'g'; '125' becomes '1', '2', '5'
            texts = []
            for t in json_data[i]['text']:
                texts.append(t)
            # Append the text/label pair to the returned data
            data.append([texts, label])
    elif is_train == 'dev':
        label = None
        data = []
        # Walk every record; there are no labels here
        for i in range(len(json_data)):
            texts = []
            for t in json_data[i]['text']:
                texts.append(t)
            data.append([texts, label])
    else:
        # 'use': build the data directly from the list of raw strings passed in
        label = None
        data = []
        for i in range(len(text_list)):
            texts = []
            for j in range(len(text_list[i])):
                texts.append(text_list[i][j])
            data.append([texts, label])
    return data


class Mydataset:
    def __init__(self, file_path, vocab, label_map, text_list, is_train):
        self.file_path = file_path
        # Preprocess the data
        self.data = data_process(self.file_path, is_train, text_list)
        self.label_map, self.label_map_inv = label_map
        self.vocab, self.vocab_inv = vocab
        # self.data holds characters and tag strings; convert both into index form
        self.examples = []
        if is_train == 'train':
            for text, label in self.data:
                t = [self.vocab.get(t, self.vocab['UNK']) for t in text]
                l = [self.label_map[l] for l in label]
                self.examples.append([t, l])
        else:
            for text, label in self.data:
                t = [self.vocab.get(t, self.vocab['UNK']) for t in text]
                l = None
                self.examples.append([t, l])

    def __getitem__(self, item):
        return self.examples[item]

    def __len__(self):
        return len(self.data)

    def collect_fn(self, batch):
        # Separate the texts and labels of one batch.
        # Both lists have length batch_size; each element keeps its original length.
        text = [t for t, l in batch]
        label = [l for t, l in batch]
        # Lengths of all sequences in the batch, length batch_size
        seq_len = [len(i) for i in text]
        # Longest sequence, used as the padding target
        max_len = max(seq_len)
        # Pad to max_len: text with 'PAD', labels with 'O'
        text = [t + [self.vocab['PAD']] * (max_len - len(t)) for t in text]
        label = [l + [self.label_map['O']] * (max_len - len(l)) for l in label]
        # Convert to tensors before feeding the model; dtype must be long, otherwise indexing fails
        # text and label shape: (batch_size, max_len)
        # seq_len shape: (batch_size,)
        text = torch.tensor(text, dtype=torch.long)
        label = torch.tensor(label, dtype=torch.long)
        seq_len = torch.tensor(seq_len, dtype=torch.long)
        return text, label, seq_len

    def Collect_Fn(self, batch):
        # Same as collect_fn, but for prediction where there are no labels
        text = [t for t, l in batch]
        # Lengths of all sequences in the batch, length batch_size
        seq_len = [len(i) for i in text]
        # Longest sequence, used as the padding target
        max_len = max(seq_len)
        # Pad to max_len with 'PAD'
        text = [t + [self.vocab['PAD']] * (max_len - len(t)) for t in text]
        # text shape: (batch_size, max_len); seq_len shape: (batch_size,)
        text = torch.tensor(text, dtype=torch.long)
        seq_len = torch.tensor(seq_len, dtype=torch.long)
        return text, seq_len
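
A minimal usage sketch of the pieces above, showing what collect_fn produces for a batch; the paths here are illustrative, replace them with your own:

from torch.utils.data import DataLoader
from Data_load import get_vocab, get_label_map, Mydataset

train_path = 'Data/new_train.json'                            # illustrative path
vocab = get_vocab(train_path, 'Data/vocab.pkl')               # (vocab, vocab_inv)
label_map = get_label_map(train_path, 'Data/label_map.json')  # (label_map, label_map_inv)

dataset = Mydataset(train_path, vocab, label_map, [], 'train')
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=dataset.collect_fn)

text, label, seq_len = next(iter(loader))
# text and label are padded to the longest sequence in the batch; seq_len keeps the true lengths
print(text.shape, label.shape, seq_len)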

Part 2: Even a skilled cook can't make a meal without a pot

A well-built furnace makes cooking painless; here we forge the pot, i.e. the model itself (Train_model_build.py):

import torch
from torch import nn


def argmax(vec):
    # Return the index of the largest element along dim 1 as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()


# log-sum-exp with improved numerical stability.
# Adapted from the original torch tutorial function so it can score both
# the batched forward pass and a single sequence.
def log_sum_exp(vec):
    max_score, _ = torch.max(vec, dim=-1)
    max_score_broadcast = max_score.unsqueeze(-1).repeat_interleave(vec.shape[-1], dim=-1)
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast), dim=-1))


class BiLSTM_CRF(nn.Module):
    def __init__(self, dataset, embedding_dim, hidden_dim, device='cpu'):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim          # embedding size
        self.hidden_dim = hidden_dim                # LSTM hidden size
        self.vocab_size = len(dataset.vocab)        # vocabulary size
        self.tagset_size = len(dataset.label_map)   # number of tags
        self.device = device
        # Current mode; 'train', 'eval' and 'pred' trigger different behaviour in forward()
        self.state = 'train'
        self.word_embeds = nn.Embedding(self.vocab_size, embedding_dim)
        # A BiLSTM concatenates both directions, doubling the output size,
        # so the hidden size passed in is halved here
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=2, bidirectional=True, batch_first=True)
        # Map the BiLSTM output to per-tag scores; these are the CRF emission scores
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size, bias=False)
        # Initialize the CRF
        self.crf = CRF(dataset, device)
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.layer_norm = nn.LayerNorm(self.hidden_dim)

    def _get_lstm_features(self, sentence, seq_len):
        embeds = self.word_embeds(sentence)
        self.dropout(embeds)
        # The input was padded, but the RNN should not run over the 'PAD' positions,
        # so pack the sequences with torch's built-in utility first
        packed = torch.nn.utils.rnn.pack_padded_sequence(embeds, seq_len, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed)
        seq_unpacked, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        sequence_output = self.layer_norm(seq_unpacked)
        lstm_feats = self.hidden2tag(sequence_output)
        return lstm_feats

    def forward(self, sentence, tags, seq_len):
        # Run the BiLSTM to get the emission scores
        feats = self._get_lstm_features(sentence, seq_len)
        # Depending on the mode, either compute the loss or run Viterbi decoding
        if self.state == 'train':
            loss = self.crf.neg_log_likelihood(feats, tags, seq_len)
            return loss
        else:
            all_tag = []
            for i, feat in enumerate(feats):
                # path_score, best_path = self.crf._viterbi_decode(feat[:seq_len[i]])
                all_tag.append(self.crf._viterbi_decode(feat[:seq_len[i]])[1])
            return all_tag


class CRF:
    def __init__(self, dataset, device='cpu'):
        self.label_map = dataset.label_map
        self.label_map_inv = dataset.label_map_inv
        self.tagset_size = len(self.label_map)
        self.device = device
        # Transition score matrix: transitions[i, j] is the score of moving from tag j to tag i
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size)).to(self.device)
        # Add start/stop tags and manually constrain the transitions:
        # nothing may transition into <START>, and nothing may leave <STOP>
        self.START_TAG = "<START>"
        self.STOP_TAG = "<STOP>"
        self.transitions.data[self.label_map[self.START_TAG], :] = -10000
        self.transitions.data[:, self.label_map[self.STOP_TAG]] = -10000

    def _forward_alg(self, feats, seq_len):
        # Initial scores: only the start tag gets 0, everything else -10000
        init_alphas = torch.full((self.tagset_size,), -10000.)
        init_alphas[self.label_map[self.START_TAG]] = 0.
        # Keep the scores of every time step; sequences of different lengths are handled
        # later by indexing each one at its own length.
        # shape: (batch_size, seq_len + 1, tagset_size)
        forward_var = torch.zeros(feats.shape[0], feats.shape[1] + 1, feats.shape[2], dtype=torch.float32,
                                  device=self.device)
        forward_var[:, 0, :] = init_alphas
        # Repeat the transition matrix batch_size times so the whole batch is computed at once
        # shape: (tagset_size, tagset_size) -> (batch_size, tagset_size, tagset_size)
        transitions = self.transitions.unsqueeze(0).repeat(feats.shape[0], 1, 1)
        # Iterate over the time steps
        for seq_i in range(feats.shape[1]):
            # Emission scores of the current step
            emit_score = feats[:, seq_i, :]
            # previous step score + transition score + current emission score
            tag_var = (
                forward_var[:, seq_i, :].unsqueeze(1).repeat(1, feats.shape[2], 1)  # (batch_size, tagset_size, tagset_size)
                + transitions
                + emit_score.unsqueeze(2).repeat(1, 1, feats.shape[2])
            )
            # clone() is required here; modifying forward_var in place breaks backpropagation
            cloned = forward_var.clone()
            cloned[:, seq_i + 1, :] = log_sum_exp(tag_var)
            forward_var = cloned
        # Pick the final score of each sequence according to its own length
        forward_var = forward_var[range(feats.shape[0]), seq_len, :]
        # Manually add the transition to the stop tag
        terminal_var = forward_var + self.transitions[self.label_map[self.STOP_TAG]].unsqueeze(0).repeat(feats.shape[0], 1)
        # log-sum of the scores of all possible paths
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _score_sentence(self, feats, tags, seq_len):
        # Score of the gold path, shape (batch_size,)
        score = torch.zeros(feats.shape[0], device=self.device)
        # Prepend the start tag so it takes part in the score computation
        start = torch.tensor([self.label_map[self.START_TAG]], device=self.device).unsqueeze(0).repeat(feats.shape[0], 1)
        tags = torch.cat([start, tags], dim=1)
        # Iterate over the batch
        for batch_i in range(feats.shape[0]):
            # Vectorized over the time steps for efficiency: sum the transition and
            # emission scores along the gold tag path only
            score[batch_i] = torch.sum(
                self.transitions[tags[batch_i, 1:seq_len[batch_i] + 1], tags[batch_i, :seq_len[batch_i]]]) \
                + torch.sum(feats[batch_i, range(seq_len[batch_i]), tags[batch_i][1:seq_len[batch_i] + 1]])
            # Finally add the transition to the stop tag
            score[batch_i] += self.transitions[self.label_map[self.STOP_TAG], tags[batch_i][seq_len[batch_i]]]
        return score

    # Viterbi decoding for the best path, following the original torch tutorial
    def _viterbi_decode(self, feats):
        backpointers = []
        # Initial scores: only the start tag gets 0, everything else -10000
        init_vvars = torch.full((1, self.tagset_size), -10000., device=self.device)
        init_vvars[0][self.label_map[self.START_TAG]] = 0
        # Scores of the previous time step
        forward_var = init_vvars
        # feats is a single sequence; iterate over its time steps
        for feat in feats:
            bptrs_t = []        # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step
            # Process one tag at a time
            for next_tag in range(self.tagset_size):
                # previous step score + transition score into tag next_tag
                next_tag_var = forward_var + self.transitions[next_tag]
                # Index of the best previous tag, i.e. where the highest score came from
                best_tag_id = argmax(next_tag_var)
                # Remember that index for backtracking
                bptrs_t.append(best_tag_id)
                # Keep the corresponding score
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Add the emission score of the current step here; it does not affect
            # which previous tag is best, so it can be added after the argmax
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            # Save the backpointers of this step
            backpointers.append(bptrs_t)
        # Manually add the transition to the stop tag
        terminal_var = forward_var + self.transitions[self.label_map[self.STOP_TAG]]
        # Index of the best final tag
        best_tag_id = argmax(terminal_var)
        # Best path score
        path_score = terminal_var[0][best_tag_id]
        # Backtrack through the backpointers to recover the best path
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag
        start = best_path.pop()
        assert start == self.label_map[self.START_TAG]  # sanity check
        # Reverse the path so it runs from first to last token
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, feats, tags, seq_len):
        # log-sum of the scores of all paths
        forward_score = self._forward_alg(feats, seq_len)
        # Score of the gold path
        gold_score = self._score_sentence(feats, tags, seq_len)
        # Return the batch mean of the negative log likelihood
        return torch.mean(forward_score - gold_score)
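
Before training on real data, a quick smoke test of the architecture helps. The sketch below pushes a tiny hand-made batch through the model; the dummy vocabulary, tags and tensors are made up purely for illustration (the model only reads .vocab, .label_map and .label_map_inv from the dataset object). In 'train' mode forward() returns the CRF negative log-likelihood; in any other mode it returns the Viterbi-decoded tag indices:

import torch
from Train_model_build import BiLSTM_CRF

class DummyDataset:
    # Minimal stand-in for Mydataset, with a made-up vocabulary and tag set
    vocab = {'PAD': 0, 'UNK': 1, '浙': 2, '商': 3, '银': 4, '行': 5}
    label_map = {'B-company': 0, 'I-company': 1, 'O': 2, '<START>': 3, '<STOP>': 4}
    label_map_inv = {v: k for k, v in label_map.items()}

model = BiLSTM_CRF(DummyDataset(), embedding_dim=16, hidden_dim=32, device='cpu')

text = torch.tensor([[2, 3, 4, 5], [2, 3, 0, 0]], dtype=torch.long)   # batch of 2, second one padded
label = torch.tensor([[0, 1, 1, 1], [0, 1, 2, 2]], dtype=torch.long)
seq_len = torch.tensor([4, 2], dtype=torch.long)

model.state = 'train'
loss = model(text, label, seq_len)    # scalar CRF negative log-likelihood
model.state = 'eval'
paths = model(text, None, seq_len)    # list of decoded tag-index lists, one per sequence
print(loss.item(), paths)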

Part 3: The cook learns to cook

Part 1 handled data loading and tensor conversion, and Part 2 built the model skeleton; now we train it to obtain the model parameters. The training control and the training loop live in one .py file to keep things easy to read. I save the model structure and parameters together (torch.save on the whole model), so the prediction script can load it without rebuilding the architecture first; keeping the weights separate from the architecture makes it easy to end up with weights that no longer match the code that rebuilds the model. A sketch of both saving options follows.
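
One practical note on the whole-model save: in this code the CRF class is a plain Python object rather than an nn.Module, so its transition matrix does not appear in model.state_dict(); torch.save on the whole model pickles it along with everything else, and loading only needs torch.load plus an importable Train_model_build. A minimal sketch of the two options, written as a helper with illustrative paths (model and train_dataset stand for the objects built in the script below):

import torch
from Train_model_build import BiLSTM_CRF

def save_and_reload(model, train_dataset,
                    whole_path='Data/BiLSTM+CRF.h5', weights_path='Data/BiLSTM+CRF.state.pt'):
    # Option used in this article: save structure and weights together;
    # loading only needs torch.load, but the BiLSTM_CRF class must still be importable.
    torch.save(model, whole_path)
    reloaded = torch.load(whole_path, map_location='cpu')

    # Alternative: save only the weights (state_dict); the model has to be rebuilt
    # with the same dataset and hyperparameters before the weights can be loaded back.
    # Note: with this code the CRF transition matrix is not part of the state_dict.
    torch.save(model.state_dict(), weights_path)
    rebuilt = BiLSTM_CRF(train_dataset, 128, 768, device='cpu')
    rebuilt.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return reloaded, rebuilt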

from Train_model_build import BiLSTM_CRF
from Data_load import *
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import time
from tqdm import tqdm
from itertools import chain
import datetime
from sklearn import metrics


def train(epochs, train_dataloader, valid_dataloader, model, device, optimizer, batch_size, train_dataset, model_save_path):
    total_start = time.time()
    best_score = 0
    for epoch in range(epochs):
        epoch_start = time.time()
        model.train()
        model.state = 'train'
        for step, (text, label, seq_len) in enumerate(train_dataloader, start=1):
            start = time.time()
            text = text.to(device)
            label = label.to(device)
            seq_len = seq_len.to(device)
            loss = model(text, label, seq_len)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(f'Epoch: [{epoch + 1}/{epochs}],'
                  f' cur_epoch_finished: {step * batch_size / len(train_dataset) * 100:2.2f}%,'
                  f' loss: {loss.item():2.4f},'
                  f' cur_step_time: {time.time() - start:2.2f}s,'
                  f' cur_epoch_remaining_time: {datetime.timedelta(seconds=int((len(train_dataloader) - step) / step * (time.time() - epoch_start)))}',
                  f' total_remaining_time: {datetime.timedelta(seconds=int((len(train_dataloader) * epochs - (len(train_dataloader) * epoch + step)) / (len(train_dataloader) * epoch + step) * (time.time() - total_start)))}')
        # Validate once per epoch and keep the best checkpoint
        score = evaluate(model, valid_dataloader, device, train_dataset)
        if score > best_score:
            print(f'score increase:{best_score} -> {score}')
            best_score = score
            torch.save(model, model_save_path)
        print(f'current best score: {best_score}')


def evaluate(model, valid_dataloader, device, train_dataset):
    # model.load_state_dict(torch.load('./model1.bin'))
    all_label = []
    all_pred = []
    model.eval()
    model.state = 'eval'
    with torch.no_grad():
        for text, label, seq_len in tqdm(valid_dataloader, desc='eval: '):
            text = text.to(device)
            seq_len = seq_len.to(device)
            batch_tag = model(text, label, seq_len)
            all_label.extend([[train_dataset.label_map_inv[t] for t in l[:seq_len[i]].tolist()] for i, l in enumerate(label)])
            all_pred.extend([[train_dataset.label_map_inv[t] for t in l] for l in batch_tag])
    all_label = list(chain.from_iterable(all_label))
    all_pred = list(chain.from_iterable(all_pred))
    sort_labels = [k for k in train_dataset.label_map.keys()]
    # Compute the macro F1 with sklearn, excluding 'O', '<START>' and '<STOP>' (the last three labels)
    f1 = metrics.f1_score(all_label, all_pred, average='macro', labels=sort_labels[:-3])
    print(metrics.classification_report(all_label, all_pred, labels=sort_labels[:-3], digits=3))
    return f1


def Train_control(train_path, valid_path, vocab_path, label_map_path, model_save_path, embedding_size, hidden_dim, epochs, batch_size, device):
    # Build the vocabulary
    vocab = get_vocab(train_path, vocab_path)
    # Build the label dictionary
    label_map = get_label_map(train_path, label_map_path)
    print("vocabulary and label map built")
    text_list = []
    train_dataset = Mydataset(train_path, vocab, label_map, text_list, 'train')
    # The dev set also carries labels, so it is loaded in 'train' mode as well
    valid_dataset = Mydataset(valid_path, vocab, label_map, text_list, 'train')
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=True,
                                  collate_fn=train_dataset.collect_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=0, pin_memory=False, shuffle=False,
                                  collate_fn=valid_dataset.collect_fn)
    model = BiLSTM_CRF(train_dataset, embedding_size, hidden_dim, device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    train(epochs, train_dataloader, valid_dataloader, model, device,
          optimizer, batch_size, train_dataset, model_save_path)


if __name__ == '__main__':
    torch.manual_seed(42)
    embedding_size = 128
    hidden_dim = 768
    epochs = 100
    batch_size = 32
    device = "cpu"
    # Paths of the training and validation sets
    train_path = '/home/ModelTrain/NLP/Data/new_train.json'
    valid_path = '/home/ModelTrain/NLP/Data/new_dev.json'
    # Where to save the vocabulary
    vocab_path = '/home/ModelTrain/NLP/Data/vocab.pkl'
    # Where to save the label dictionary
    label_map_path = '/home/ModelTrain/NLP/Data/label_map.json'
    # Where to save the model
    model_save_path = '/home/ModelTrain/NLP/Data/BiLSTM+CRF.h5'
    Train_control(train_path, valid_path, vocab_path, label_map_path, model_save_path, embedding_size, hidden_dim, epochs, batch_size, device)

Part 4: With pot and rice, time to cook

This part covers loading and calling the model, plus preprocessing the data before it goes in. No BERT here: on my business data, BERT actually did not perform as well, probably because of the corpus.

import time
import torch
from torch.utils.data import DataLoader
from itertools import chain
import json
from Data_load import *


def vector2text(string, predict):
    # Convert a BIO tag sequence back into a list of entities
    item = {"string": string, "entities": []}
    entity_name = ""
    flag, items = [], []
    visit = False
    for char, tag in zip(string, predict):
        if tag[0] == "B":
            if entity_name != "":
                x = dict((a, flag.count(a)) for a in flag)
                y = [k for k, v in x.items() if max(x.values()) == v]
                item["entities"].append({"word": entity_name, "type": y[0]})
                items.append([entity_name, y[0]])
                flag.clear()
                entity_name = ""
            visit = True
            entity_name += char
            flag.append(tag[2:])
        elif tag[0] == "I" and visit:
            entity_name += char
            flag.append(tag[2:])
        else:
            if entity_name != "":
                x = dict((a, flag.count(a)) for a in flag)
                y = [k for k, v in x.items() if max(x.values()) == v]
                item["entities"].append({"word": entity_name, "type": y[0]})
                items.append([entity_name, y[0]])
                flag.clear()
            flag.clear()
            visit = False
            entity_name = ""
    if entity_name != "":
        x = dict((a, flag.count(a)) for a in flag)
        y = [k for k, v in x.items() if max(x.values()) == v]
        item["entities"].append({"word": entity_name, "type": y[0]})
        items.append([entity_name, y[0]])
    return items


def data_get(data_path):
    # Read the JSON data into a list whose elements are the input sentences
    with open(data_path, 'r', encoding='utf-8') as fp:
        json_data = [json.loads(line) for line in fp]
    texts = [''.join([t for t in json_data[i]['text']]) for i in range(len(json_data))]
    return texts


def predict(vocab_path, label_map_path, data_path, model_path, device, model_state, text_list):
    start = time.time()
    # Load the vocabulary (the data path argument is unused because vocab.pkl already exists)
    vocab = get_vocab('0', vocab_path)
    # Load the label dictionary
    label_map = get_label_map('0', label_map_path)
    global label_map_index
    # get_label_map returns (label_map, label_map_inv); keep the inverse map (index -> tag)
    for i in range(len(label_map)):
        label_map_index = label_map[i]
    dataset = Mydataset(data_path, vocab, label_map, text_list, 'use')
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0, pin_memory=False, shuffle=False,
                            collate_fn=dataset.Collect_Fn)
    # Load the .h5 file that contains both the model structure and the weights
    model = torch.load(model_path, map_location=device)
    model.eval()
    model.state = model_state
    result = []
    with torch.no_grad():
        k = -1
        for text, seq_len in dataloader:
            k = k + 1
            text = text.to(device)
            seq_len = seq_len.to(device)
            batch_tag = model(text, None, seq_len)
            predict = [[label_map_index[t] for t in l] for l in batch_tag]
            for i in range(len(predict)):
                items = vector2text(text_list[k * len(predict) + i], predict[i])
                result.append([text_list[k * len(predict) + i]] + items)
    for i in range(len(result)):
        print(result[i])
    end = time.time()
    time_s = end - start
    print("******Using Time:" + str(time_s) + "******")


# Paths and inputs; adjust them to your own environment
vocab_path = '/home/ModelTrain/NLP/Data/vocab.pkl'
label_map_path = '/home/ModelTrain/NLP/Data/label_map.json'
data_path = '/home/ModelTrain/NLP/Data/new_test.json'
model_path = '/home/ModelTrain/NLP/Data/BiLSTM+CRF.h5'
device = "cpu"
model_state = 'eval'
text_list = ['警情通报', '近日', '忻府区又有多位居民被电信网络诈骗', '丽都锦城小区居民李女士在网络平台刷单被诈骗396240元', '田森汇小区居民陈女士在网络平台刷单被诈骗33000元', '机械局宿舍居民于女士在网络平台刷单被诈骗20000元', '解原乡乔村马先生被冒充客服以退款为由诈骗170000元', '鑫立佳苑小区居民王女士在网络平台投资理财被诈骗510000元', '警方提示', '1、刷单', '刷单', '还是刷单被骗你想得骗子的返利', '骗子想得你的本金', '你相信骗子说的再刷一单连本带利就返还了', '骗子想到底还能刷多少钱进账就能把你拉黑了', '2、接到“客服”电话、短信等', '到正规平台核实后再进行操作', '防止上当受骗', '3、网络投资理财要选择官方、正规的投资平台', '切勿用“好友”发送的链接或二维码下载陌生APP进行理财', '这类理财软件只能看到数字在平台增加', '永远不能提现', '忻州市公安局直属分局', '2023年5月24日']
predict(vocab_path, label_map_path, data_path, model_path, device, model_state, text_list)
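
To make the entity-extraction step concrete, here is what the vector2text function above returns for a hand-written example; the tag sequence is made up for illustration (with real input it comes from the model):

# Hand-crafted example; the tags are not a real model prediction.
string = "浙商银行叶老桂"
tags = ["B-company", "I-company", "I-company", "I-company", "B-name", "I-name", "I-name"]
print(vector2text(string, tags))
# [['浙商银行', 'company'], ['叶老桂', 'name']]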

With all the parts above, we have built a BiLSTM+CRF named entity recognition model, trained it, and loaded it for inference. On top of this, you can train an entity extraction model suited to your own business from your production corpus. The earlier version of the inference code had a few slips, so I reworked it; the current version works well.
