当前位置:   article > 正文

NLP(四十五)R-BERT在人物关系分类上的尝试及Keras代码复现_关系分类模型

关系分类模型

  本文将介绍关系分类模型R-BERT和该模型在人物关系数据集上的表现,以及该模型的Keras代码复现。

关系分类任务

  关系分类属于NLP任务中的文本分类,不同之处在于,关系分类提供了文本和实体。比如下面的例子:

亲戚 1837年6月20日,威廉四世辞世,他的侄女维多利亚即位。

其中两个实体在文本中用和包围着,人物关系为亲戚。

  在关系分类中,我们要注重文本特征,更要留意实体特征。常见的英文关系分类的数据集为SemEval 2010 Task 8、New York Times Corpus、WikiData dataset for Sentential Relation Extraction、NYT29、NYT24等,中文的关系分类数据集比较少,而且质量不高。
  关于SemEval 2010 Task 8数据集的实现模型及效果,可以参考:http://nlpprogress.com/english/relationship_extraction.html, 其中常见的实现模型如下:

  • Machince Learning: SVM, Word2Vec …
  • Dependency Models: BRCNN, DRNN …
  • CNN-based Models: Multi-Attention CNN, Attention CNN, PCNN+ATT …
  • BERT-based Models: R-BERT, Matching-the-Blanks …

  本文将介绍R-BERT模型。

模型介绍

  R-BERT模型是Alibaba Group (U.S.) Inc的两位研究者在2019年5月的论文Enriching Pre-trained Language Model with Entity Information for Relation Classification,该模型在SemEval 2010 Task 8数据集上的F1值为89.25%,只比现有的SOTA模型低了0.25%。
  R-BERT很好地融合了文本特征以及两个实体在文本中的特征,简单来说,该模型主要是BERT模型中的三个向量的融合:

  • [CLS]对应的向量
  • 实体1的平均向量
  • 实体2的平均向量

  下面将详细讲解R-BERT的具体模型结构。

模型结构

  R-BERT的具体模型结构如下图:
R-BERT模型结构图

  一图胜千言。从上述的模型结构图中,我们将模型结构分解步骤如下:

  1. 将文本接入BERT模型,获取[CLS] token的对应向量、实体1的在BERT输出层中的平均向量、实体2的在BERT输出层中的平均向量;
  2. 将上述三个向量分别接Drouput层、Tanh激活层以及全连接层;
  3. 再将步骤2输出的三个向量进行拼接(concatenate);
  4. 最后接Dropout层和全连接层,用Softmax作为多分类的激活函数。

此外,需要注意的是,输入文本中没有[SEP]这个token。
  论文中并没有给出更多的实现细节,需要深入到代码中去查看。网上已经有人给出了Torch框架的实现R-BERT的代码,参考网址为:https://github.com/monologg/R-BERT。

Torch实现

  Torch框架的实现R-BERT的代码(模型部分)如下:

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel


class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
        super(FCLayer, self).__init__()
        self.use_activation = use_activation
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.dropout(x)
        if self.use_activation:
            x = self.tanh(x)
        return self.linear(x)


class RBERT(BertPreTrainedModel):
    def __init__(self, config, args):
        super(RBERT, self).__init__(config)
        self.bert = BertModel(config=config)  # Load pretrained bert

        self.num_labels = config.num_labels

        self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
        self.entity_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
        self.label_classifier = FCLayer(
            config.hidden_size * 3,
            config.num_labels,
            args.dropout_rate,
            use_activation=False,
        )

    @staticmethod
    def entity_average(hidden_output, e_mask):
        """
        Average the entity hidden state vectors (H_i ~ H_j)
        :param hidden_output: [batch_size, j-i+1, dim]
        :param e_mask: [batch_size, max_seq_len]
                e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
        :return: [batch_size, dim]
        """
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # [b, 1, j-i+1]
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1)  # [batch_size, 1]

        # [b, 1, j-i+1] * [b, j-i+1, dim] = [b, 1, dim] -> [b, dim]
        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
        avg_vector = sum_vector.float() / length_tensor.float()  # broadcasting
        return avg_vector

    def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask):
        outputs = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        # Average
        e1_h = self.entity_average(sequence_output, e1_mask)
        e2_h = self.entity_average(sequence_output, e2_mask)

        # Dropout -> tanh -> fc_layer (Share FC layer for e1 and e2)
        pooled_output = self.cls_fc_layer(pooled_output)
        e1_h = self.entity_fc_layer(e1_h)
        e2_h = self.entity_fc_layer(e2_h)

        # Concat -> fc_layer
        concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
        logits = self.label_classifier(concat_h)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        # Softmax
        if labels is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            outputs = (loss,) + outputs

        return outputs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87

  该项目是在SemEval 2010 Task 8数据集实现的,笔者将其在自己的人物关系分类数据集上进行测试,最终在测试集上的评估结果如下:

# Model: chinese-roberta-wwm-ext, weighted avgage F1 = 85.35%
# Model: chinese-roberta-wwm-ext-large, weighted avgage F1 = 87.22%
  • 1
  • 2

Model: chinese-roberta-wwm-ext-large, 详细的评估结果如下:

                precision    recall  f1-score   support

     unknown      0.8756    0.8421    0.8585       209
         上下级    0.7297    0.8710    0.7941        31
          亲戚     0.8421    0.6667    0.7442        24
        兄弟姐妹    0.8333    0.8824    0.8571        34
          合作     0.9074    0.8305    0.8673        59
          同人     0.9744    0.9744    0.9744        39
          同学     0.9130    0.8750    0.8936        24
          同门     0.9630    1.0000    0.9811        26
          夫妻     0.8372    0.9114    0.8727        79
          好友     0.8438    0.9000    0.8710        30
          师生     0.8378    0.8378    0.8378        37
          情侣     0.8125    0.8387    0.8254        31
          父母     0.8931    0.9141    0.9035       128
          祖孙     0.9545    0.8400    0.8936        25

    accuracy                         0.8724       776
   macro avg     0.8727    0.8703    0.8696       776
weighted avg     0.8743    0.8724    0.8722       776
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

  R-BERT模型在人物关系数据集上的Github项目为
R-BERT_for_people_relation_extraction
。下面将介绍R-BERT模型的Keras框架复现。

Keras复现

  R-BERT模型的Keras框架复现(模型部分)的代码如下:

# -*- coding: utf-8 -*-
# main architecture of R-BERT
from keras.models import Model
from keras.utils import plot_model
from keras.layers import Input, Lambda, Dense, Dropout, concatenate, Dot
from keras_bert import load_trained_model_from_checkpoint


# model structure of R-BERT
class RBERT(object):
    def __init__(self, config_path, checkpoint_path, maxlen, num_labels):
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.maxlen = maxlen
        self.num_labels = num_labels

    def create_model(self):
        # BERT model
        bert_model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path, seq_len=None)
        for layer in bert_model.layers:
            layer.trainable = True
        x1_in = Input(shape=(self.maxlen,))
        x2_in = Input(shape=(self.maxlen,))
        bert_layer = bert_model([x1_in, x2_in])

        # get three vectors
        cls_layer = Lambda(lambda x: x[:, 0])(bert_layer)    # 取出[CLS]对应的向量
        e1_mask = Input(shape=(self.maxlen,))
        e2_mask = Input(shape=(self.maxlen,))
        e1_layer = self.entity_average(bert_layer, e1_mask)  # 取出实体1对应的向量
        e2_layer = self.entity_average(bert_layer, e2_mask)  # 取出实体2对应的向量

        # dropout -> linear -> concatenate
        output_dim = cls_layer.shape[-1].value
        cls_fc_layer = self.crate_fc_layer(cls_layer, output_dim, dropout_rate=0.1)
        e1_fc_layer = self.crate_fc_layer(e1_layer, output_dim, dropout_rate=0.1)
        e2_fc_layer = self.crate_fc_layer(e2_layer, output_dim, dropout_rate=0.1)
        concat_layer = concatenate([cls_fc_layer, e1_fc_layer, e2_fc_layer], axis=-1)

        # FC layer for classification
        output = Dense(self.num_labels, activation="softmax")(concat_layer)
        model = Model([x1_in, x2_in, e1_mask, e2_mask], output)
        model.summary()
        return model

    @staticmethod
    def crate_fc_layer(input_layer, output_dim, dropout_rate=0.0, activation_func="tanh"):
        dropout_layer = Dropout(rate=dropout_rate)(input_layer)
        linear_layer = Dense(output_dim, activation=activation_func)(dropout_layer)
        return linear_layer

    @staticmethod
    def entity_average(hidden_output, e_mask):
        """
        Average the entity hidden state vectors (H_i ~ H_j)
        :param hidden_output: BERT hidden output
        :param e_mask:
                e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]/num_of_ones
        :return: entity average layer
        """
        avg_layer = Dot(axes=1)([e_mask, hidden_output])
        return avg_layer
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62

总结

  R-BERT模型再次见证了BERT等预训练模型的强大。该模型的实现思路比较简单,也取得了很不错的效果,是关系分类任务的一大突破。
  当然对笔者来说,也有种重要的意义:第一次自己复现了论文代码,虽然有Torch代码可以参考。
  本文分享到此结束,感谢阅读~
  2021年4月1日于上海杨浦,次日大雾迷城~

参考文献

  • NLP-progress Relation Extraction: http://nlpprogress.com/english/relationship_extraction.html
  • Huggingface Transformers: https://github.com/huggingface/transformers
  • https://github.com/wang-h/bert-relation-classification
  • R-BERT: https://github.com/monologg/R-BERT
  • Enriching Pre-trained Language Model with Entity Information for Relation Classification: https://arxiv.org/pdf/1905.08284.pdf
  • Chinese-BERT-wwm: https://github.com/ymcui/Chinese-BERT-wwm
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/312466
推荐阅读
相关标签
  

闽ICP备14008679号