当前位置:   article > 正文

BiLSTM-Attention实现关系抽取(基于pytorch)_基于 bert-wwm 和 bilstm-attention 的关系抽取模型

基于 bert-wwm 和 bilstm-attention 的关系抽取模型

概述

虽然tensorflow2.0发布以来还是收获了一批用户,但是在自然语言处理领域,似乎pytorch见的更多一点。关系抽取是目前自然语言处理的主流任务之一,遗憾没能找到较新能用的开源代码。一方面是因为关系抽取任务的复杂性,目前数据集较少,且标注的成本极高,尤其是中文数据集,所以针对该任务的数据集屈指可数,这也限制了这方面的研究。另一方面,关系抽取任务的复杂性,程序多数不可通用。github上有pytorch版本的BiLSTM-attention的开源代码,然而基于python2且pytorch版本较低。目前没有基于python3,tf2的BiLSTM-Attention关系抽取任务的开源代码。我在这篇博客中会写使用python3,基于pytorch框架实现BiLSTM-Attention进行关系抽取的主要代码(无关紧要的就不写啦)。(学生党还是弃坑tensorflow1.x吧,一口老血。。。)

关系抽取

其实关系抽取可以归为信息抽取的一部分。信息抽取是当前自然语言处理的热点之一。信息抽取是知识图谱,文本摘要等任务的核心环节,但是就目前的研究来看,当前的技术仍不成熟,所消耗的资源较多且研究结果差强人意。对于构建知识图谱来说,实体识别,关系抽取,实体融合是不可缺少的要素。当前,联合关系抽取有许多经典模型,但是效果一般。在可以保证实体识别的高准确率的情况下,还是建议使用pipeline方法,即先识别实体,后进行实体之间的关系抽取。本文介绍在已经有实体的基础上,进行关系抽取的经典模型,BiLSTM-Attention,该模型在NLP中很多地方都有它的身影,尤其是文本分类任务中。进行关系抽取时,也是把句子进行分类任务,这种情况下,关系抽取也叫做关系分类。理论基础来源于这篇文章:文章地址
关于LSTM和attention我就不多赘述了,网上的资料很多。我们直接来看一下架构:
在这里插入图片描述这是文章中的架构图。其实也很简单,字符经过嵌入后传给LSTM层,编码之后经过Attention层,然后进行目标的预测。这一看就是个最简单的文本分类的结构。那么关系抽取是怎么解决的呢?关系抽取其实就是在嵌入时,加入了实体的特征,与句子特征融合起来,丢给神经网络进行关系分类。

数据集

数据集展示在这里插入图片描述这个数据集是关系抽取中最常见的数据集,人物关系抽取。第一列,第二列是实体,第三列是他们之间的关系,后面是两个实体所处的句子。总共11个类别+unknown,在文件relation2id中。
在这里插入图片描述### 数据预处理

def get_label_distribution(relation_file_path,data_file_path):
    relation2id = {}
    with open(relation_file_path, "r", encoding="utf-8") as fr:
        for line in fr.readlines():
            line = line.strip().split(" ")
            relation2id[line[0]] = int(line[1])
    import pandas as pd
    label = []
    with open(data_file_path, encoding='utf-8') as fr:
        for line in fr.readlines():
            line = line.split("\t")
            label.append(relation2id[line[2]])
    df = pd.Series(label).value_counts()
    return df
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

上述代码是为了得到标签的分布,读取文件后将标签信息写入relation2id的字典中。label 中写入的是数据集中的标签。

def flatten_lists(lists):
    flatten_list = []
    for l in lists:
        if type(l) == list:
            flatten_list += l
        else:
            flatten_list.append(l)
    return flatten_list
def flat_gen(x):
    def is_elment(el):
        return not(isinstance(el,collections.Iterable) and not isinstance(el,str))
    for el in x:
        if is_elment(el):
            yield el
        else:
            yield from flat_gen(el)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

上述代码是将整个数据文件转换成单行列表。
评价模型部分的代码省略,基本都是一个套路。

模型构建

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1)  

class BiLSTM_ATT(nn.Module):
    def __init__(self,input_size,output_size,config,pre_embedding):
        super(BiLSTM_ATT,self).__init__()
        self.batch = config['BATCH']

        self.input_size = input_size
        self.embedding_dim = config['EMBEDDING_DIM']
        
        self.hidden_dim = config['HIDDEN_DIM']
        self.tag_size = output_size 
        
        self.pos_size = config['POS_SIZE']
        self.pos_dim = config['POS_DIM'] 
        
        self.pretrained = config['pretrained']

        if self.pretrained:
            self.word_embeds = nn.Embedding.from_pretrained(torch.FloatTensor(pre_embedding),freeze=False)
        else:
            self.word_embeds = nn.Embedding(self.input_size,self.embedding_dim)

        self.pos1_embeds = nn.Embedding(self.pos_size,self.pos_dim) 
        self.pos2_embeds = nn.Embedding(self.pos_size,self.pos_dim) 
        self.dense = nn.Linear(self.hidden_dim,self.tag_size,bias=True)
        self.relation_embeds = nn.Embedding(self.tag_size,self.hidden_dim)
        self.lstm = nn.LSTM(input_size=self.embedding_dim+self.pos_dim*2,hidden_size=self.hidden_dim//2,num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(self.hidden_dim,self.tag_size)

        self.dropout_emb = nn.Dropout(p=0.5)
        self.dropout_lstm = nn.Dropout(p=0.5)
        self.dropout_att = nn.Dropout(p=0.5)
        
        self.hidden = self.init_hidden()
        self.att_weight = nn.Parameter(torch.randn(self.batch,1,self.hidden_dim))
        self.relation_bias = nn.Parameter(torch.randn(self.batch,self.tag_size,1))
        
    def init_hidden(self):
        return torch.randn(2, self.batch, self.hidden_dim // 2)
        
    def init_hidden_lstm(self):
        return (torch.randn(2, self.batch, self.hidden_dim // 2),
                torch.randn(2, self.batch, self.hidden_dim // 2))
  
    def attention(self,H):
        M = torch.tanh(H) # 非线性变换 size:(batch_size,hidden_dim,seq_len)
        a = F.softmax(torch.bmm(self.att_weight,M),dim=2) # a.Size : (batch_size,1,seq_len)
        a = torch.transpose(a,1,2) # (batch_size,seq_len,1)
        return torch.bmm(H,a) # (batch_size,hidden_dim,1)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

LSTM的输入是实体1的位置信息+实体2的微信信息+嵌入信息。LSTM的output保存了最后一层的输出h。

    def forward(self,sentence,pos1,pos2):
        self.hidden = self.init_hidden_lstm()
        embeds = torch.cat((self.word_embeds(sentence),self.pos1_embeds(pos1),self.pos2_embeds(pos2)),dim=2)
        embeds = torch.transpose(embeds,0,1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
 
        lstm_out = lstm_out.permute(1,2,0)
        lstm_out = self.dropout_lstm(lstm_out)

        att_out = torch.tanh(self.attention(lstm_out ))
        relation = torch.tensor([i for i in range(self.tag_size)], dtype=torch.long).repeat(self.batch, 1)
        relation = self.relation_embeds(relation)
        out = torch.add(torch.bmm(relation, att_out), self.relation_bias)
        out = F.softmax(out,dim=1)
        return out.view(self.batch,-1) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

这一段主要是维度变换的工作,将数据处理成模型所需要的维度。上面有一些配置信息是在另一个文件夹中统一编写的,基本的模型就是这样。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/543275
推荐阅读
相关标签
  

闽ICP备14008679号