当前位置:   article > 正文

BERT实现关系分类抽取(pytorch)_token_ids

token_ids

前言

        前两天在GitHub上看到这样一个关系分类抽取模型(地址:GitHub),项目的任务是给定句子和句子中的两个实体,判断这两个实体之间的关系,项目文件中带有质量较高的数据集。项目的思路大致是将关系抽取转化成对两个实体的关系进行分类,感觉算是实体抽取的入门,所以在此记录一下。

        最近看到一个观点说是我们写一个深度学习任务的项目,大致可以分为三块:model,也就是要定义我们模型各层的结构;dataset,把要训练的数据集改写成model需要的格式;train,定义训练、验证和测试的参数和过程。有人说先写dataset,有人说先写model,但是我觉得还是model更重要,dataset是由model的输入层和输出层决定的,所以只要我们自己明确了model的输入和输出,dataset通过一些Python基础很容易就实现了。但是model才是整个项目的核心所在,所有亮点和创新都体现在这里,一个小小的改动可能就会对最后结果造成很大影响。train部分相对来说最固定,基本都是一些重复性的代码。所以在这里对这个项目的model进行一下解释。

        这个项目的model在relation_extraction/model.py文件中,我把源码贴在这里,感兴趣的可以通过这个model来自己补全dataset和train。

  1. import torch
  2. import torch.nn as nn
  3. from transformers import BertModel
  4. class SentenceRE(nn.Module):
  5. def __init__(self, hparams):
  6. super(SentenceRE, self).__init__()
  7. self.pretrained_model_path = hparams.pretrained_model_path or 'bert-base-chinese'
  8. self.embedding_dim = hparams.embedding_dim
  9. self.dropout = hparams.dropout
  10. self.tagset_size = hparams.tagset_size
  11. self.bert_model = BertModel.from_pretrained(self.pretrained_model_path)
  12. self.dense = nn.Linear(self.embedding_dim, self.embedding_dim)
  13. self.drop = nn.Dropout(self.dropout)
  14. self.activation = nn.Tanh()
  15. self.norm = nn.LayerNorm(self.embedding_dim * 3)
  16. self.hidden2tag = nn.Linear(self.embedding_dim * 3, self.tagset_size)
  17. def forward(self, token_ids, token_type_ids, attention_mask, e1_mask, e2_mask):
  18. sequence_output, pooled_output = self.bert_model(input_ids=token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)
  19. # 每个实体的所有token向量的平均值
  20. e1_h = self.entity_average(sequence_output, e1_mask)
  21. e2_h = self.entity_average(sequence_output, e2_mask)
  22. e1_h = self.activation(self.dense(e1_h))
  23. e2_h = self.activation(self.dense(e2_h))
  24. # [cls] + 实体1 + 实体2
  25. concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=1)
  26. concat_h = self.norm(concat_h)
  27. logits = self.hidden2tag(self.drop(concat_h))
  28. return logits
  29. @staticmethod
  30. def entity_average(hidden_output, e_mask):
  31. """
  32. Average the entity hidden state vectors (H_i ~ H_j)
  33. :param hidden_output: [batch_size, max_len, dim]
  34. :param e_mask: [batch_size, max_seq_len]
  35. e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
  36. :return: [batch_size, dim]
  37. """
  38. # (batch_size,1,max_len)
  39. e_mask_unsqueeze = e_mask.unsqueeze(1)
  40. length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1) # [batch_size, 1]
  41. sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1) # [b, 1, max_len] * [b, max_len, dim] = [b, 1, dim] -> [b, dim]
  42. avg_vector = sum_vector.float() / length_tensor.float() # broadcasting
  43. return avg_vector

输入输出

        首先明确model的输入输出,前面已经介绍了本项目的任务是对给定的句子和句子中的两个实体判断关系,所以输入就是句子、实体1和实体2,输出就是关系编号。然后看forward函数的输入输出,forward函数的参数包括以下5个:

        1、token_ids:这个比较常见,标记着句子中每个字在词表中的位置。

        2、token_type_ids:区分两个句子的编码,但是在本次项目中只输入一个句子。

        3、attention_mask:标记哪些位置要进行self-attention操作,其他位置都是pad。

        4、e1_mask:标记第一个实体的位置。

        5、e2_mask:标记第二个实体的位置。

        输出的是模型对两个实体的关系预测结果。


中间层

        模型的中间层实际上体现在forward函数中,上面的初始化函数对各层的定义顺序没有要求,所以看forward函数才能真正掌握模型的数据流动过程。

        1、forward函数中首先是一个bert层,先将对整个句子进行编码,得到句中每个字的字向量sequence_output和整句的句向量pooled_output。这些向量都是transformer最后一层的输出。

sequence_output, pooled_output = self.bert_model(input_ids=token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)

        2、通过前面的e1_mask和e2_mask来找到两个实体每个字的字向量,然后每个实体的所有字向量相加后求均值。这部分工作体现在entity_average函数中,该函数的作用就是求一个实体中所有字的字向量均值,过程比较巧妙,所以对这个函数每句进行一些解释。函数的输入一个是句子的句向量hidden_output,维度为(batch_size,seq_l,embedding_dim),还有一个是实体的位置e_mask,这是一个(batch_size,seq_len)的向量,实体位置为1,其余位置为0。

                ①、第一步对e_mask向量进行升维,从(batch_size,seq_len)变为(batch_size,1,seq_len),这步的作用是对齐e_mask和hidden_output的维度,使得后面二者可以相乘。

e_mask_unsqueeze = e_mask.unsqueeze(1)

                ②、第二步求出每个实体的字数,使用的sum函数可以统计向量中值为1的元素数量,然后再进行升维。

length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1)

                ③、第三步要计算实体所有字向量的平均值,用e_mask和hidden_output相乘,其中e_mask为0的位置,与非实体位置的字向量相乘,e_mask为1的位置与实体位置的每个字向量相乘,这样就实现了实体每个字的字向量相加。这里使用了bmm批量矩阵乘法函数,二者的维度分别为(batch_size,1,seq_l)和(batch_size,seq_len,embedding_dim),在理解bmm函数时可以先不考虑batch_size维度,因为三维矩阵可以看做是多个二维矩阵的结合,bmm函数每次分别取两个二维矩阵相乘,然后将所有结果组合成一个三维矩阵。也就是说:

[batch_size,1,max_len] × [batch_size,max_len,dim] 可以看做是

batch_size ×([1,max_len] × [max_len,dim])= batch_size × [1,dim] = [batch_size,1,dim]

然后再将中间那个维度去掉,得到最后结果(batch_size,embedd_dim)。

sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)

                ④、最后将相加后的向量除以实体长度,这样就实现了实体字向量的求均值。在这里就可以看出第二步对字数升维的作用,是为了对齐维度。

avg_vector = sum_vector.float() / length_tensor.float()

        3、求出的两个实体字向量均值经过激活函数后,与[CLS]位置的向量连接起来,然后经过一个归一化层和Dropout层,最后通过一个全连接层求分类。

  1. concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=1)
  2. concat_h = self.norm(concat_h)
  3. logits = self.hidden2tag(self.drop(concat_h))
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/378076
推荐阅读
相关标签
  

闽ICP备14008679号