当前位置:   article > 正文

中文医学知识图谱中的命名实体识别和关系抽取源码学习

中文医学知识图谱

目录

引言

一、模型定义与下载

二、依赖库和配置文件

三、数据训练过程 

1、关系类别加载过程

2、数据加载过程

3、数据迭代过程

4、主体预测过程

5、主体与客体关系预测过程

6、实际训练过程

7、损失函数计算过程

8、模型评估过程

四、程序运行结果

1、加载模型

2、接收文本内容以及命名实体识别模型和关系抽取模型

3、提取输入文本中的命名实体及其关系

4、运行结果 


引言

中文医学知识图谱(Chinese Medical Knowledge Graph, CMeKG)是利用自然语言处理与文本挖掘技术,基于大规模医学文本数据,以人机结合的方式研发的中文医学知识图谱。
此项目中主要模型工具包括:医学实体识别和医学关系抽取。项目源码地址


一、模型定义与下载

由于依赖和训练好的的模型较大,将模型放到了百度网盘中,链接如下,按需下载。

命名实体识别(Named Entity Recognition,NER),是指识别文本中具有特定意义的实体,主要包括人名、地名、机构名、专有名词等,以及时间、数量、货币、比例数值等文字。即用专有名词(名称)标识的事物,一个命名实体一般代表唯一一个具体事物个体,预测文本中具有特定意义的实体。

NER:链接:https://pan.baidu.com/s/16TPSMtHean3u9dJSXF9mTw  密码:shwh

关系抽取(Relation Extraction, RE)是指若有两个存在着关系的实体,将两个实体分别成为主体和客体,那么关系抽取就是在非结构或半结构化数据中找出主体与客体之间存在的关系,并将其表示为实体关系三元组抽取(主体Subject,谓语Predicate,客体Object),即通过NER得到了实体之后,预测任意两个实体存在怎样的关系。

RE:链接:https://pan.baidu.com/s/1cIse6JO2H78heXu7DNewmg  密码:4s6k


二、依赖库和配置文件

依赖库代码如下 import:

  1. import gc # 用于释放不再使用的内存资源
  2. import json # 用于处理 JSON 格式的数据
  3. import random # 用于模型训练时数据集的随机化
  4. import re # 用于文本数据的预处理,去除特殊字符或提取特定模式的信息
  5. import time # 用于记录训练时间
  6. from itertools import cycle # 用于循环遍历数据集,确保每个样本都能被处理到
  7. import numpy as np # 通常与PyTorch一起使用,用于处理数值数据和数组
  8. import torch # 提供张量操作、模型定义、梯度计算等功能
  9. import torch.nn as nn # PyTorch中的神经网络模块,用于定义和训练神经网络模型
  10. # 来自 Hugging Face Transformers 库的模块,用于处理基于 BERT 模型的自然语言处理任务。BertTokenizer 用于对文本进行分词,BertModel 是预训练的 BERT 模型,而 AdamW 是一种优化器,常用于调整模型参数以最小化损失函数。
  11. from transformers import BertTokenizer, BertModel, AdamW
  12. import warnings
  13. warnings.filterwarnings('ignore') # 去掉红色的FutureWarning警告提示

配置文件代码如下 config:

  1. class config:
  2. """
  3. 配置类,包含模型训练和运行时的参数
  4. """
  5. batch_size = 1 # 批处理大小(根据自己电脑配置进行修改)
  6. max_seq_len = 256 # 最大序列长度256个字符
  7. num_p = 23 # 关系种类数(数据中自定义)
  8. learning_rate = 1e-5 # 学习率
  9. EPOCH = 2 # 训练轮次
  10. # 文件路径
  11. PATH_SCHEMA = "/predicate.json" # 关系种类文件路径
  12. PATH_TRAIN = '/train_example.json' # 训练数据文件路径
  13. PATH_BERT = "/bert_model/config.json" # BERT模型文件夹路径
  14. PATH_MODEL = "/model_re.pkl" # 训练好的关系抽取模型文件,包含了模型参数、优化器参数等
  15. PATH_SAVE = '/save' # 模型保存路径
  16. tokenizer = BertTokenizer.from_pretrained("/bert_model/vocab.txt")
  17. id2predicate = {} # 关系id到名称的映射
  18. predicate2id = {} # 关系名称到id的映射

三、数据训练过程 

训练数据过程代码如下:

  1. def run_train():
  2. """
  3. 训练数据过程
  4. """
  5. load_schema(config.PATH_SCHEMA) # 调用load_schema()函数,读取predicate.json
  6. train_path = config.PATH_TRAIN # 表示训练数据的路径,读取train_example.json
  7. all_data = load_data(train_path) # 加载训练数据,返回包含数据和标签的列表。调用load_data()函数,读取train_example.json所有数据
  8. random.shuffle(all_data) # 打乱训练数据的顺序,以增加模型的泛化性能
  9. idx = int(len(all_data) * 0.8) # 计算训练集和验证集的分割点,8:2划分训练集和验证集,idx = 19
  10. train_data = all_data[:idx] # 获取训练集,即前80%的数据,随机的19个训练集
  11. valid_data = all_data[idx:] # 获取验证集,即剩余20%的数据,剩下的5个验证集
  12. # train训练
  13. train_data_loader = IterableDataset(train_data, True) # 创建可迭代的数据加载器,用于训练模型,调用IterableDataset()函数
  14. num_train_data = len(train_data) # 获取训练数据的数量,num_train_data = 19
  15. checkpoint = torch.load(config.PATH_MODEL) # 加载预训练的模型的检查点,包含了模型参数、优化器参数等,读取model_re.pkl
  16. model4s = Model4s() # 创建用于预测主体的模型,调用Model4s()函数
  17. model4s.load_state_dict(checkpoint['model4s_state_dict']) # 加载预训练主体预测模型的参数
  18. model4po = Model4po() # 创建用于计算主体之间关系的模型,调用Model4po()函数
  19. model4po.load_state_dict(checkpoint['model4po_state_dict']) # 加载预训练关系计算模型的参数
  20. # 模型学习与迭代的权重衰减策略
  21. param_optimizer = list(model4s.named_parameters()) + list(model4po.named_parameters())
  22. no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
  23. optimizer_grouped_parameters = [
  24. {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
  25. {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  26. ]
  27. lr = config.learning_rate
  28. optimizer = AdamW(optimizer_grouped_parameters, lr=lr) # 构建模型学习与迭代的权重衰减策略,配置优化器AdamW
  29. optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载预训练优化器的参数
  30. checkpoint = train(train_data_loader, model4s, model4po, optimizer) # 调用train()函数,,并获取训练后的检查点
  31. del train_data # 删除训练数据变量,释放内存
  32. gc.collect() # 释放不再使用的内存
  33. model_path = config.PATH_SAVE # 保存训练后的模型检查点
  34. torch.save(checkpoint, model_path)
  35. print('saved!')
  36. model4s.eval() # 将模型设置为评估模式,用于验证
  37. model4po.eval()
  38. f1, precision, recall = evaluate(valid_data, True, model4s, model4po) # 调用 evaluate() 函数对验证集进行评估,计算 F1 值、精确度和召回率
  39. print('f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall))

1、关系类别加载过程

列出来所有的关系及其对应的id,代码如下

  1. def load_schema(path):
  2. """
  3. 列出来的所有关系的数量和类别(自己做项目可以自定义predicate.josn文件)及其对应的id
  4. """
  5. with open(path, 'r', encoding='utf-8', errors='replace') as f:
  6. # 从文件中加载关系数据,data为一个字典,包含关系类别和对应数量
  7. data = json.load(f) # data = {'相关疾病': 12410, '相关症状': 10880, '临床表现': 94657, '检查': 8543, '用法用量': 10914, ...}
  8. # 获取关系类别列表,predicate为所有关系类别的集合
  9. predicate = list(data.keys()) # predicate = ['相关疾病','相关症状', '临床表现', '检查', '用法用量',...]
  10. # 创建一个字典,用于将关系类别predicate映射为对应的id
  11. prediction2id = {} # prediction2id = {'相关疾病': 0,'相关症状': 1, '临床表现': 2, '检查': 3, '用法用量': 4,...]
  12. # 创建一个字典,用于将id映射回关系类别predicate
  13. id2predicate = {} # id2predicate = {0: '相关疾病',1: '相关症状', 2: '临床表现', 3: '检查',4: '用法用'量,...}
  14. # 遍历关系类别列表
  15. for i in range(len(predicate)): # 将关系类别映射为id,建立关系类别到id的映射,i = 22
  16. prediction2id[predicate[i]] = i # 将关系类别映射为id,建立关系类别到id的映射
  17. id2predicate[i] = predicate[i] # 将id映射回关系类别,建立id到关系类别的映射
  18. num_p = len(predicate) # 获取关系类别的数量,即总共有多少种关系,num_p = 23
  19. config.prediction2id = prediction2id # 将关系类别到id的映射存储在配置文件中
  20. config.id2predicate = id2predicate # 将id到关系类别的映射存储在配置文件中
  21. config.num_p = num_p # 将关系类别的数量存储在配置文件中

2、数据加载过程

加载数据代码如下:

  1. def load_data(path): # 定义了加载数据的函数,接受一个文件路径参数,path = 'train_example.json'
  2. """
  3. 加载数据:把输入数据和spo三元组放到字典当中
  4. """
  5. text_spos = [] # 初始化一个空列表,用于存储每个样本的文本和对应的spo三元组
  6. with open(path, 'r', encoding='utf-8', errors='replace') as f:
  7. data = json.load(f) # 从文件中加载JSON格式的数据,存储在data变量中
  8. # data =
  9. # [
  10. # {
  11. # "text": "12小时尿沉渣计数的相关疾病:单纯型尿路感染,妊娠合并急性膀胱炎,慢性肾炎,狼疮性肾炎,急性膀胱炎12小时尿沉渣计数的相关症状是高血压,男子性功能障碍,蛋白尿,血尿,水肿,排尿困难及尿潴留,尿频伴尿急和尿痛",
  12. # "spo_list": [
  13. # [
  14. # "12小时尿沉渣计数",
  15. # "相关疾病",
  16. # "单纯型尿路感染"
  17. # ],
  18. # [
  19. # "12小时尿沉渣计数",
  20. # "相关疾病",
  21. # "妊娠合并急性膀胱炎"
  22. # ],
  23. # [
  24. # "12小时尿沉渣计数",
  25. # "相关疾病",
  26. # "慢性肾炎"
  27. # ],
  28. # [
  29. # "12小时尿沉渣计数",
  30. # "相关疾病",
  31. # "狼疮性肾炎"
  32. # ],
  33. # [
  34. # "12小时尿沉渣计数",
  35. # "相关疾病",
  36. # "急性膀胱炎"
  37. # ],
  38. # [
  39. # "12小时尿沉渣计数",
  40. # "相关症状",
  41. # "高血压"
  42. # ],
  43. # [
  44. # "12小时尿沉渣计数",
  45. # "相关症状",
  46. # "男子性功能障碍"
  47. # ],
  48. # [
  49. # "12小时尿沉渣计数",
  50. # "相关症状",
  51. # "蛋白尿"
  52. # ],
  53. # [
  54. # "12小时尿沉渣计数",
  55. # "相关症状",
  56. # "血尿"
  57. # ],
  58. # [
  59. # "12小时尿沉渣计数",
  60. # "相关症状",
  61. # "水肿"
  62. # ],
  63. # [
  64. # "12小时尿沉渣计数",
  65. # "相关症状",
  66. # "排尿困难及尿潴留"
  67. # ],
  68. # [
  69. # "12小时尿沉渣计数",
  70. # "相关症状",
  71. # "尿频伴尿急和尿痛"
  72. # ],
  73. # [
  74. # "12小时尿沉渣计数",
  75. # "相关疾病",
  76. # "肾炎"
  77. # ],
  78. # [
  79. # "12小时尿沉渣计数",
  80. # "相关疾病",
  81. # "尿路感染"
  82. # ],
  83. # [
  84. # "12小时尿沉渣计数",
  85. # "相关症状",
  86. # "排尿困难"
  87. # ],
  88. # [
  89. # "12小时尿沉渣计数",
  90. # "相关症状",
  91. # "尿潴留"
  92. # ],
  93. # [
  94. # "12小时尿沉渣计数",
  95. # "相关症状",
  96. # "尿频"
  97. # ],
  98. # [
  99. # "12小时尿沉渣计数",
  100. # "相关症状",
  101. # "尿急"
  102. # ],
  103. # [
  104. # "12小时尿沉渣计数",
  105. # "相关症状",
  106. # "尿痛"
  107. # ]
  108. # ]
  109. # },...
  110. for item in data: # 遍历JSON数据中的每个样本
  111. # 获取文本内容和spo三元组
  112. # 获取当前样本的文本内容,text = "12小时尿沉渣计数的相关疾病:单纯型尿路感染,妊娠合并急性膀胱炎,慢性肾炎,狼疮性肾炎,急性膀胱炎12小时尿沉渣计数的相关症状是高血压,男子性功能障碍,蛋白尿,血尿,水肿,排尿困难及尿潴留,尿频伴尿急和尿痛"...
  113. text = item['text']
  114. # 获取当前样本的spo三元组列表,spo_list = [['12小时尿沉渣计数, 相关疾病, 单纯型尿路感染'], ['12小时尿沉渣计数, 相关疾病, 妊娠合并急性膀胱炎'], ['12小时尿沉渣计数, 相关疾病, 慢性肾炎'],...]
  115. spo_list = item['spo_list']
  116. # 将文本和spo三元组以字典的形式添加到text_spos列表中
  117. text_spos.append({
  118. 'text': text,
  119. 'spo_list': spo_list
  120. })
  121. return text_spos # 返回存储文本和对应spo三元组的列表

3、数据迭代过程

 迭代数据代码ru

  1. class IterableDataset(torch.utils.data.IterableDataset):
  2. """
  3. 迭代数据:定义了一个继承自PyTorch可迭代数据集的类,用于处理命名实体识别和关系抽取的任务
  4. """
  5. def __init__(self, data, random): # 初始化函数,接收数据和随机化数据
  6. super(IterableDataset).__init__() # 调用父类的初始化函数
  7. self.data = data # data = all_data(接收已经加载的所有数据)
  8. self.random = random # 随机打乱数据
  9. self.tokenizer = config.tokenizer # 使用配置文件中的分词器
  10. def __len__(self):
  11. return len(self.data) # 定义len()方法,返回数据集的长度
  12. # (sequence =)batch_token_ids = [[101 8123 4905 3710 1825 7000 5868 5843 5131 3800 2198 3890
  13. # 4638 679 5679 1353 2418 131 6783 3800 2428 6814 2571 3198 1377 772 4495
  14. # 2626 2552 510 1445 1402 510 1355 4178 5023 4568 4307 8039 1453 1741
  15. # 7474 5549 4017 3800 6862 2428 6814 2571 3198 981 1377 1355 4495 7474
  16. # 5549 4142 8024 2418 1217 3800 2692 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0,...]]63
  17. # pattern = 主体s的编码 s_rand = [8123 4905 3710 1825 7000 5868 5843 5131 3800 2198 3890]
  18. # 定义搜索子序列的方法,用于找到主体的起始位置
  19. def search(self, sequence, pattern):
  20. n = len(pattern)
  21. for i in range(len(sequence)):
  22. if sequence[i:i + n] == pattern:
  23. return i
  24. return -1 # 如果能找到则返回 pattern 开始的位置,找不到则返回-1
  25. # 定义处理数据的方法,包括数据的预处理和标签的生成)
  26. def process_data(self):
  27. idxs = list(range(len(self.data))) # 生成数据索引列表 idxs = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18]
  28. if self.random: # 判断是否需要随机化数据顺序
  29. np.random.shuffle(idxs) # 如果需要,随机打乱数据idxs顺序,idxs = [5, 9, 4, 8, 18, 7, 0, 6, 15, 17, 11, 1, 14, 2, 16, 10, 13, 12, 3]
  30. batch_size = config.batch_size # 获取批次大小 batch_size = 1
  31. max_seq_len = config.max_seq_len # 获取最大序列长度 max_seq_len = 256(如果不设置的话原模型设置的最大长度是512,如果句子长度超过512会报错)
  32. num_p = config.num_p # 获取类别数量 num_p = 23
  33. # 1、初始化各类数据数组:batch_token_ids【1,256】,一段话最多有256个字符,每个字符所对应的id
  34. batch_token_ids = np.zeros((batch_size, max_seq_len), dtype=np.int) # 初始化:batch_token_ids = [[00000000000000000000000, 00000000000000000000000,...]]
  35. # 2、哪些字符要参与到Transformer和self attention中的计算当中
  36. batch_mask_ids = np.zeros((batch_size, max_seq_len), dtype=np.int)
  37. # 3、一篇文字有多句话组成,指定是第几句话
  38. batch_segment_ids = np.zeros((batch_size, max_seq_len), dtype=np.int)
  39. # 4、找到主体开始S和结束的位置E,输入样本是一个主体,而不是一个句子,训练spo模型的能力(一个主体和其他个体之间的关系)
  40. batch_subject_ids = np.zeros((batch_size, 2), dtype=np.int) # batch_subject_ids = [[0 0]]
  41. # 5、定义标签【1,256,2】,每个主体都可以成为一个主体的起始和终止位置
  42. batch_subject_labels = np.zeros((batch_size, max_seq_len, 2), dtype=np.int) # batch_subject_labels = [[[0, 0], [0, 0], [0, 0],...]]
  43. # 6、确定主体后,找谁是它对应的客体和他们之间的关系(属于23当中的哪一个)
  44. batch_object_labels = np.zeros((batch_size, max_seq_len, num_p, 2), dtype=np.int)
  45. batch_i = 0
  46. # 遍历数据集索引(在for循环中首先从字典中取出所需处理的文本,然后使用 tokenizer.encode() 进行解析,返回切分之后的编码情况)
  47. for i in idxs: # i = 5
  48. # text = "18种氨基酸葡萄糖注射液的不良反应:输注速度过快时可产生恶心、呕吐、发热等症状;周围静脉滴注速度过快时偶可发生静脉炎,应加注意"
  49. text = self.data[i]['text'] # 获取当前数据文本
  50. # 使用分词器tokenizer,将文本字符转换成对应的id,并进行padding操作(超过256做截断,不足256做补0)
  51. batch_token_ids[batch_i, :] = self.tokenizer.encode(text, max_length=max_seq_len, pad_to_max_length=True, add_special_tokens=True)
  52. # batch_token_ids = [[101 8123 4905 3710 1825 7000 5868 5843 5131 3800 2198 3890
  53. # 4638 679 5679 1353 2418 131 6783 3800 2428 6814 2571
  54. # 3198 1377 772 4495 2626 2552 510 1445 1402 510 1355 4178
  55. # 5023 4568 4307 8039 1453 1741 7474 5549 4017 3800 6862 2428 6814 2571 3198 981 1377 1355 4495 7474
  56. # 5549 4142 8024 2418 1217 3800 2692 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0,...]]63
  57. # batch_mask_ids = [[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...65个1...0 0 0 0 0 ]] 【1×256】
  58. batch_mask_ids[batch_i, :len(text) + 2] = 1 # 设置对应的mask(将里面文本长度加2的值置1表明文本长度,预留出开头和结尾的输出)
  59. # spo_list = [
  60. # [
  61. # "18种氨基酸葡萄糖注射液",
  62. # "不良反应",
  63. # "呕吐"
  64. # ],
  65. # [
  66. # "18种氨基酸葡萄糖注射液",
  67. # "不良反应",
  68. # "恶心"
  69. # ],
  70. # [
  71. # "18种氨基酸葡萄糖注射液",
  72. # "不良反应",
  73. # "发热"
  74. # ],
  75. # [
  76. # "18种氨基酸葡萄糖注射液",
  77. # "不良反应",
  78. # "恶心、呕吐"
  79. # ]
  80. # ]
  81. # },
  82. spo_list = self.data[i]['spo_list']
  83. idx = np.random.randint(0, len(spo_list), size=1)[0] # 随机选择一个主体并获取其起始位置,然后在文本中进行匹配,找到该主体位置并记录。相当于每次都是随机选一个主体s来组成数据,即"18种氨基酸葡萄糖注射液"
  84. # 处理标签,包括主体和客体的标签
  85. s_rand = self.tokenizer.encode(spo_list[idx][0])[1:-1] # 主体s的编码 s_rand = [8123 4905 3710 1825 7000 5868 5843 5131 3800 2198 3890] 不取101和102
  86. s_rand_idx = self.search(list(batch_token_ids[batch_i, :]), s_rand) # 调用search函数找到主体s的起始位置 s_rand_idx = 1
  87. batch_subject_ids[batch_i, :] = [s_rand_idx, s_rand_idx + len(s_rand) - 1] # 找到主体s的终止位置
  88. for i in range(len(spo_list)): # 对spo_list进行遍历,如果主体、客体都在文本中存在,则记录下主体位置和关系,同时如果该主体与随机主体相同,则记录下客体的位置
  89. spo = spo_list[i] # spo = ["18种氨基酸葡萄糖注射液", "不良反应", "呕吐"]
  90. s = self.tokenizer.encode(spo[0])[1:-1] # s = [8123 4905 3710 1825 7000 5868 5843 5131 3800 2198 3890] 不取101和102
  91. p = config.prediction2id[spo[1]] # p = 8(预测的类别不良反应)
  92. o = self.tokenizer.encode(spo[2])[1:-1] # o = [1445 1402]
  93. s_idx = self.search(list(batch_token_ids[batch_i]), s) # s主体起始的位置 s_idx = 1
  94. o_idx = self.search(list(batch_token_ids[batch_i]), o) # o客体起始的位置 o_idx = 31
  95. if s_idx != -1 and o_idx != -1: # 主体和客体都存在的话
  96. batch_subject_labels[batch_i, s_idx, 0] = 1 # 要预测每一个token是不是主体的起始和终止位置
  97. batch_subject_labels[batch_i, s_idx + len(s) - 1, 1] = 1 # 0是起始位置,1是终止位置
  98. if s_idx == s_rand_idx:
  99. batch_object_labels[batch_i, o_idx, p, 0] = 1 # 记录客体o的开始位置
  100. batch_object_labels[batch_i, o_idx + len(o) - 1, p, 1] = 1 # 记录客体o的终止位置
  101. batch_i += 1
  102. if batch_i == batch_size or i == idxs[-1]:
  103. # 生成一个批次的数据和标签, 重置数据数组,准备处理下一个批次
  104. yield batch_token_ids, batch_mask_ids, batch_segment_ids, batch_subject_labels, batch_subject_ids, batch_object_labels
  105. batch_token_ids[:, :] = 0
  106. batch_mask_ids[:, :] = 0
  107. batch_subject_ids[:, :] = 0
  108. batch_subject_labels[:, :, :] = 0
  109. batch_object_labels[:, :, :, :] = 0
  110. batch_i = 0
  111. def get_stream(self):
  112. return cycle(self.process_data()) # 定义获取数据流的方法,使用cycle不断循环调用process_data()生成的批次数据
  113. def __iter__(self):
  114. return self.get_stream() # 返回数据流

4、主体预测过程

预测主体代码如下:

  1. class Model4s(nn.Module):
  2. """
  3. 预测主体:每个位置都要预测一下是不是一个主体,即找主体所在的位置,S开始是不是主体,E结束是不是主体
  4. """
  5. def __init__(self, hidden_size=768): # 初始化方法,接受隐藏层维度作为参数,默认为768
  6. super(Model4s, self).__init__()
  7. self.bert = BertModel.from_pretrained(config.PATH_BERT) # 加载预训练的BERT模型
  8. for param in self.bert.parameters(): # 设置BERT模型的所有参数可训练
  9. param.requires_grad = True
  10. self.dropout = nn.Dropout(p=0.2) # 定义Dropout层,用于防止过拟合
  11. self.linear = nn.Linear(in_features=hidden_size, out_features=2, bias=True) # 定义全连接层,用于进行二分类任务,输出为2,表示判断每个位置是否是主体的起始或结束
  12. self.sigmoid = nn.Sigmoid() # 定义Sigmoid激活函数,用于将输出映射到[0, 1]之间
  13. def forward(self, input_ids, input_mask, segment_ids, hidden_size=768): # 定义前向传播方法
  14. hidden_states = self.bert(input_ids, # 使用BERT模型进行前向传播,得到隐藏层的输出
  15. attention_mask=input_mask,
  16. token_type_ids=segment_ids)[0] # 隐藏特征(batch_size, sequence_length, hidden_size)
  17. output = self.sigmoid(self.linear(self.dropout(hidden_states))).pow(2) # 对BERT的隐藏层输出进行Dropout操作,然后通过全连接层和Sigmoid激活函数,得到预测结果。使用pow(2)进行平方项操作,用于筛选概率值
  18. return output, hidden_states # 返回预测结果和BERT模型隐藏层的输出

5、主体与客体关系预测过程

 预测主体与客体之间的关系代码如下:

  1. class Model4po(nn.Module):
  2. """
  3. 预测主体与客体之间的关系(抽象)
  4. """
  5. def __init__(self, num_p=config.num_p, hidden_size=768): # 初始化方法,接受关系类别数num_p和隐藏层维度hidden_size作为参数,默认num_p为23,hidden_size为768
  6. super(Model4po, self).__init__()
  7. self.dropout = nn.Dropout(p=0.4) # 定义Dropout层,用于防止过拟合
  8. self.linear = nn.Linear(in_features=hidden_size, out_features=num_p * 2, bias=True) # 定义全连接层,用于进行关系抽取任务,输出为num_p*2,表示预测每个位置的关系
  9. self.sigmoid = nn.Sigmoid() # Sigmoid激活函数,用于将输出映射到[0, 1]之间
  10. def forward(self, hidden_states, batch_subject_ids, input_mask): # 定义前向传播方法,接受隐藏层输出hidden_states、主体位置batch_subject_ids和输入掩码input_mask
  11. all_s = torch.zeros((hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]),
  12. dtype=torch.float32) # 初始化一个全零的张量,用于存储主体的特征
  13. for b in range(hidden_states.shape[0]): # 遍历batch中的每个样本
  14. s_start = batch_subject_ids[b][0] # 获取当前样本主体的起始位置
  15. s_end = batch_subject_ids[b][1] # 获取当前样本主体的终止位置
  16. s = hidden_states[b][s_start] + hidden_states[b][s_end] # 得到当前样本主体的特征表示,即主体起始特征和终止特征的和
  17. cue_len = torch.sum(input_mask[b]) # 计算当前样本的实际长度,用于确定主体在序列中的有效范围
  18. all_s[b, :cue_len, :] = s # 将主体的特征填充到主体位置上
  19. hidden_states += all_s # 将每一个位置的实际特征都更新为:自身特征 + 主体特征
  20. output = self.sigmoid(self.linear(self.dropout(hidden_states))).pow(4) # 对经过Dropout后的隐藏层输出进行全连接和Sigmoid激活,得到关系抽取的预测结果,并进行平方项筛选
  21. return output # 返回关系抽取的预测结果,维度为(batch_size, max_seq_len, num_p*2)

6、实际训练过程

实现训练过程代码如下:

  1. def train(train_data_loader, model4s, model4po, optimizer):
  2. """
  3. 实现训练过程:定义训练函数,接收训练数据加载器、model4s、model4po和优化器作为输入
  4. """
  5. for epoch in range(config.EPOCH): # 遍历每个epoch进行训练
  6. begin_time = time.time()
  7. # 将模型切换为训练模式
  8. model4s.train()
  9. model4po.train()
  10. train_loss = 0
  11. # 遍历训练数据加载器,获取每个批次的数据
  12. for bi, batch in enumerate(train_data_loader):
  13. if bi >= len(train_data_loader) // config.batch_size:
  14. break # 防止越界,提前结束循环
  15. # 数据转换为tensor,将其用于模型训练
  16. batch_token_ids, batch_mask_ids, batch_segment_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = batch
  17. batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long)
  18. batch_mask_ids = torch.tensor(batch_mask_ids, dtype=torch.long)
  19. batch_segment_ids = torch.tensor(batch_segment_ids, dtype=torch.long)
  20. batch_subject_labels = torch.tensor(batch_subject_labels, dtype=torch.float)
  21. batch_object_labels = torch.tensor(batch_object_labels, dtype=torch.float).view(config.batch_size,
  22. config.max_seq_len,
  23. config.num_p * 2)
  24. batch_subject_ids = torch.tensor(batch_subject_ids, dtype=torch.int)
  25. # model4s前向传播,得到预测结果和隐藏层输出
  26. batch_subject_labels_pred, hidden_states = model4s(batch_token_ids, batch_mask_ids, batch_segment_ids)
  27. # 计算4s模型的损失
  28. loss4s = loss_fn(batch_subject_labels_pred, batch_subject_labels.to(torch.float32))
  29. loss4s = torch.mean(loss4s, dim=2, keepdim=False) * batch_mask_ids
  30. loss4s = torch.sum(loss4s)
  31. loss4s = loss4s / torch.sum(batch_mask_ids)
  32. # model4po前向传播,得到预测结果
  33. batch_object_labels_pred = model4po(hidden_states, batch_subject_ids, batch_mask_ids)
  34. # 计算model4po的损失
  35. loss4po = loss_fn(batch_object_labels_pred, batch_object_labels.to(torch.float32))
  36. loss4po = torch.mean(loss4po, dim=2, keepdim=False) * batch_mask_ids
  37. loss4po = torch.sum(loss4po)
  38. loss4po = loss4po / torch.sum(batch_mask_ids)
  39. loss = loss4s + loss4po
  40. optimizer.zero_grad() # 清空过往梯度
  41. loss.backward() # 反向传播,计算当前梯度
  42. optimizer.step() # 根据梯度更新网络参数
  43. # 累加训练损失
  44. train_loss += float(loss.item())
  45. # 打印当前批次的损失
  46. print('batch:', bi, 'loss:', float(loss.item()))
  47. # 打印最终训练损失和训练耗时
  48. print('final train_loss:', train_loss / len(train_data_loader) * config.batch_size, 'cost time:',
  49. time.time() - begin_time)
  50. del train_data_loader # 删除训练数据加载器
  51. gc.collect();
  52. # 返回model4s、model4po和优化器的状态字典
  53. return {
  54. "model4s_state_dict": model4s.state_dict(),
  55. "model4po_state_dict": model4po.state_dict(),
  56. "optimizer_state_dict": optimizer.state_dict(),
  57. }

7、损失函数计算过程

 定义损失函数代码如下:

  1. def loss_fn(pred, target):
  2. """
  3. 定义损失函数
  4. """
  5. # 使用二元交叉熵损失函数,reduction='none'表示不对每个样本的损失进行求和
  6. loss_fct = nn.BCELoss(reduction='none')
  7. # 计算预测值与目标值之间的损失,返回损失值
  8. return loss_fct(pred, target)

8、模型评估过程

评估模型在给定数据集上的性能代码如下:

  1. def evaluate(data, is_print, model4s, model4po):
  2. """
  3. 评估模型在给定数据集上的性能
  4. """
  5. X, Y, Z = 1e-10, 1e-10, 1e-10 # 初始化变量,用于计算F1、Precision和Recall,避免分母为零的情况
  6. for d in data: # 遍历数据集中的每个样本
  7. R = set([SPO(spo) for spo in extract_spoes(d['text'], model4s, model4po)]) # 使用模型提取出的extract_spoes函数获取三元组集合R
  8. T = set([SPO(spo) for spo in d['spo_list']]) # 使用真实标签中的三元组集合T
  9. if is_print:
  10. print('text:', d['text']) # 打印当前样本的文本内容
  11. print('R:', R) # 打印模型提取出的三元组集合R
  12. print('T:', T) # 打印真实标签中的三元组集合T
  13. X += len(R & T) # 计算模型提取出的三元组集合中正确的个数
  14. Y += len(R) # 计算模型提取出的三元组总个数
  15. Z += len(T) # 计算真实标签中的三元组总个数
  16. f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z # 计算F1、Precision和Recall
  17. return f1, precision, recall # 返回评估结果

四、程序运行结果

运行程序代码如下:

  1. if __name__ == "__main__":
  2. # 读取训练数据文件
  3. with open(config.PATH_TRAIN, 'r', encoding="utf-8", errors='replace') as f:
  4. data = json.load(f) # 加载训练数据
  5. # 打开一个新的文件,将加载的数据以JSON格式写入,以便进行格式修改和保存
  6. f1 = open("../train.json", "w", encoding="utf-8")
  7. json.dump(data, f1, ensure_ascii=False, indent=True) # 将数据以JSON格式写入新文件
  8. print("finish")
  9. # 加载关系模型的schema,该schema用于关系抽取模型中的标签映射
  10. load_schema(config.PATH_SCHEMA)
  11. # 调用load_model()函数加载命名实体识别和关系抽取的模型,得到model4s和model4po两个模型
  12. model4s, model4po = load_model()
  13. # 输入一个包含待处理文本的字符串变量text
  14. text = "据报道称,新冠肺炎患者经常会发热、咳嗽,少部分患者会胸闷、=乏力,其病因包括: 1.自身免疫系统缺陷\n2.人传人。"
  15. # 调用get_triples()函数,传入文本和加载的两个模型,得到文本中的命名实体和关系三元组
  16. res = get_triples(text, model4s, model4po)
  17. print(res) # 打印最终结果

1、加载模型

命名实体识别和关系抽取模型加载代码如下:

  1. def load_model():
  2. """
  3. 加载命名实体识别和关系抽取的模型
  4. """
  5. load_schema(config.PATH_SCHEMA) # 调用函数load_schema(),载入预定义的关系模式
  6. checkpoint = torch.load(config.PATH_MODEL, map_location='cpu') # 从指定路径加载模型的检查点,并指定在CPU上进行计算
  7. model4s = Model4s() # 创建命名实体识别模型对象,实例化了前面提到的Model4s类
  8. model4s.load_state_dict(checkpoint['model4s_state_dict']) # 从加载的检查点中载入命名实体识别模型的权重参数
  9. model4po = Model4po() # 创建关系抽取模型对象,实例化了前面提到的Model4po类
  10. model4po.load_state_dict(checkpoint['model4po_state_dict']) # 从加载的检查点中载入关系抽取模型的权重参数
  11. return model4s, model4po # 返回加载的命名实体识别模型和关系抽取模型

2、接收文本内容以及命名实体识别模型和关系抽取模型

 接收三组数据代码如下:

  1. def get_triples(content, model4s, model4po):
  2. """
  3. 接收文本内容content以及命名实体识别模型model4s和关系抽取模型model4po作为参数
  4. """
  5. if len(content) == 0:
  6. return [] # 如果输入文本为空,则返回空列表
  7. text_list = content.split('。')[:-1] # 将输入文本按句号分割成列表,去掉最后一个空句号,得到一个句子的列表
  8. res = [] # 初始化结果列表,用于保存每个句子的命名实体及其关系
  9. for text in text_list: # 遍历每个句子
  10. if len(text) > 128: # 如果句子长度大于128,则截断为前128个字符
  11. text = text[:128]
  12. triples = extract_spoes(text, model4s, model4po) # 调用extract_spoes()函数,提取句子中的命名实体及其关系
  13. res.append({
  14. 'text': text, # 将原始句子文本保存到结果字典中
  15. 'triples': triples # 将提取得到的命名实体及其关系保存到结果字典中
  16. })
  17. return res # 返回包含每个句子命名实体及其关系的列表

3、提取输入文本中的命名实体及其关系

 提取输入文本中的命名实体及其关系三元组代码如下:

  1. def extract_spoes(text, model4s, model4po):
  2. """
  3. 该函数用于从输入文本中提取命名实体及其关系三元组
  4. """
  5. with torch.no_grad(): # 关闭梯度计算,减小内存消耗
  6. tokenizer = config.tokenizer # 加载文本处理器(Tokenizer)
  7. max_seq_len = config.max_seq_len # 获取模型所能处理的最大序列长度
  8. token_ids = torch.tensor(tokenizer.encode(text, max_length=max_seq_len, pad_to_max_length=True, add_special_tokens=True)).view(1, -1) # 将文本转换为模型可接受的输入格式
  9. if len(text) > max_seq_len - 2:
  10. text = text[:max_seq_len - 2]
  11. mask_ids = torch.tensor([1] * (len(text) + 2) + [0] * (max_seq_len - len(text) - 2)).view(1, -1) # 创建用于掩码的张量,标记文本的有效部分
  12. segment_ids = torch.tensor([0] * max_seq_len).view(1, -1) # 创建用于标识文本片段的张量
  13. subject_labels_pred, hidden_states = model4s(token_ids, mask_ids, segment_ids) # 使用命名实体识别模型预测主体位置
  14. subject_labels_pred = subject_labels_pred.cpu() # 将预测结果移动到CPU上
  15. subject_labels_pred[0, len(text) + 2:, :] = 0 # 将超出文本长度的部分概率置零
  16. start = np.where(subject_labels_pred[0, :, 0] > 0.4)[0]
  17. end = np.where(subject_labels_pred[0, :, 1] > 0.4)[0] # 根据预测概率大于0.4的位置,获取主体的起始和结束位置
  18. subjects = [] # 存储主体的列表
  19. for i in start: # 遍历主体的起始和结束位置
  20. j = end[end >= i]
  21. if len(j) > 0:
  22. j = j[0]
  23. subjects.append((i, j))
  24. if len(subjects) == 0:
  25. return []
  26. subject_ids = torch.tensor(subjects).view(1, -1)
  27. spoes = []
  28. for s in subjects:
  29. object_labels_pred = model4po(hidden_states, subject_ids, mask_ids) # 使用关系抽取模型预测关系三元组
  30. object_labels_pred = object_labels_pred.view((1, max_seq_len, config.num_p, 2)).cpu() # 调整关系抽取模型的预测结果的形状
  31. object_labels_pred[0, len(text) + 2:, :, :] = 0 # 将超出文本长度的部分概率置零
  32. start = np.where(object_labels_pred[0, :, :, 0] > 0.4)
  33. end = np.where(object_labels_pred[0, :, :, 1] > 0.4) # 根据预测概率大于0.4的位置,获取关系三元组的起始和结束位置
  34. for _start, predicate1 in zip(*start): # 遍历关系三元组的起始和结束位置
  35. for _end, predicate2 in zip(*end):
  36. if _start <= _end and predicate1 == predicate2:
  37. spoes.append((s, predicate1, (_start, _end))) # 将提取到的主体、关系、客体三元组添加到结果列表
  38. break
  39. id_str = ['[CLS]'] # 初始化标识字符串列表,用于存储处理后的文本
  40. i = 1
  41. index = 0
  42. while i < token_ids.shape[1]: # 循环处理模型输入的标识符
  43. if token_ids[0][i] == 102: # 检查是否到达标志序列的终点
  44. break
  45. word = tokenizer.decode(token_ids[0, i:i + 1]) # 解码标识符并去除特殊字符
  46. word = re.sub('#+', '', word)
  47. if word != '[UNK]': # 检查是否为未知标记
  48. id_str.append(word)
  49. index += len(word)
  50. i += 1
  51. else:
  52. j = i + 1
  53. while j < token_ids.shape[1]: # 初始化索引j,寻找下一个标志
  54. if token_ids[0][j] == 102: # 检查是否到达标志序列的终点
  55. break
  56. word_j = tokenizer.decode(token_ids[0, j:j + 1])
  57. if word_j != '[UNK]':
  58. break
  59. j += 1
  60. if token_ids[0][j] == 102 or j == token_ids.shape[1]:
  61. while i < j - 1: # 更新标识字符串列表,排除填充标志
  62. id_str.append('')
  63. i += 1
  64. id_str.append(text[index:]) # 将文本添加到标识字符串列表
  65. i += 1
  66. break
  67. else:
  68. index_end = text[index:].find(word_j)
  69. word = text[index:index + index_end]
  70. id_str.append(word)
  71. index += index_end
  72. i += 1
  73. res = [] # 初始化最终结果列表
  74. for s, p, o in spoes: # 遍历提取到的关系三元组
  75. s_start = s[0] # 获取主体的起始和结束位置
  76. s_end = s[1]
  77. sub = ''.join(id_str[s_start:s_end + 1]) # 从标识字符串列表中获取主体
  78. o_start = o[0]
  79. o_end = o[1] # 获取客体的起始和结束位置
  80. obj = ''.join(id_str[o_start:o_end + 1]) # 从标识字符串列表中获取客体
  81. res.append((sub, config.id2predicate[p], obj)) # 将构建好的关系三元组添加到最终结果列表
  82. return res # 返回最终结果列表

4、运行结果 

运行结果如下: 

  1. decode:[CLS]据报道称,新冠肺炎患者经常会发热、咳嗽,少部分患者会胸闷、乏力,其病因包括: 1.自身免疫系统缺陷2.人传人。[SEP][PAD][PAD][PAD][PAD]
  2. [
  3. {
  4. "text":"据报道称,新冠肺炎患者经常会发热、咳嗽,少部分患者会胸闷、乏力,其病因包括:1.自身免疫系统缺陷\n2.人传人"
  5. "triples": [
  6. [
  7. "新冠肺炎",
  8. "临床表现",
  9. "肺炎"
  10. ],
  11. [
  12. "新冠肺炎",
  13. "临床表现",
  14. "发热"
  15. ],
  16. [
  17. "新冠肺炎",
  18. "临床表现",
  19. "咳嗽"
  20. ],
  21. [
  22. "新冠肺炎",
  23. "临床表现",
  24. "胸闷"
  25. ],
  26. [
  27. "新冠肺炎",
  28. "临床表现",
  29. "乏力"
  30. ],
  31. [
  32. "新冠肺炎",
  33. "病因",
  34. "自身免疫系统缺陷"
  35. ],
  36. [
  37. "新冠肺炎",
  38. "病因",
  39. "人传人"
  40. ]
  41. ]
  42. }
  43. ]

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/705610
推荐阅读
相关标签
  

闽ICP备14008679号