当前位置:   article > 正文

XLNet实现超长文本分类

长文本分类

Bert只能处理长度小于512的序列,算上一些[CLS],[SEP],实际的长度要小于512。因此对于超长文本来说,Bert的效果可能一般,尤其是那些更加依赖于文档中后部分内容的下游任务。因此本文尝试使用transformers的XLNet提升超长文本多标签分类的效果。关于XLNet的介绍略。

预训练模型下载与加载

官网搜索自己想要的模型并下载对应pytorch版本的文件:
在这里插入图片描述
使用的时候可以参照官方文档给出的范例:

from transformers import XLNetTokenizer, XLNetModel
import torch
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased-path')
model = XLNetModel.from_pretrained('xlnet-base-cased-path')
  • 1
  • 2
  • 3
  • 4

其中参数是刚刚下载的文件夹所在的路径。

数据预处理

同Bert数据预处理一样,XLNet同样采样tokenize+convert to id的形式,并且对于我自己的数据,需要人为地补全<pad><csl><sep>。注意,在Bert里,这几位仁兄写作[PAD],[CLS],[SEP]。在官方文档里,是这样解释的:
在这里插入图片描述
注意在使用预训练模型并微调的时候,这几个字符是非常有必要的,尤其是<cls>,在数据预处理的时候不添加这个会导致模型性能有明显地下降。数据预处理的代码很简单:

def load_data():
    with open('../data/data.txt', 'r', encoding='utf-8') as rf:
        datas = [each.strip().split('\t') for each in rf.readlines()]
    # process data to Bert input
    Datas = []
    for data in datas:
        labels = sorted([int(a) for a in data[0].split('-')])
        labels = trans_label2n_hot(labels, 35)
        # notice there some differences from Bert:
        # <sep> for XLNet but [SEP] for Bert; <cls> for XLNet but [CLS] for Bert and <pad> for [PAD]
        sentence = '<cls>' + data[1].replace(' ','').replace('<SEP>','<sep>')
        tokens = tokenizer.tokenize(sentence)[:xlnet_cfg.max_len]
        if len(tokens) < xlnet_cfg.max_len:        # padding for max length
            tokens.extend(['<pad>'] * (xlnet_cfg.max_len - len(tokens)))
        ids = np.array(tokenizer.convert_tokens_to_ids(tokens))
        labels = torch.from_numpy(labels)
        ids = torch.from_numpy(ids)
        Datas.append([ids, labels])
    split = int(len(Datas) * xlnet_cfg.train_test_split)
    Trains = Datas[:split]
    Tests = Datas[split:]   # split+193
    train_loader = DataLoader(Trains, xlnet_cfg.batch_size, shuffle=True)
    test_loader = DataLoader(Tests, xlnet_cfg.batch_size, shuffle=True)
    print('data load finished! {}\t{}'.format(len(Trains), len(Tests)))
    return train_loader, test_loader
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

注意这里,因为是多标签的分类,所以data.txt中文件存储的方式如下:
在这里插入图片描述
前边的数字表示的是后边长本文所属的标签。

微调XLNet

微调的过程也和Bert的一样,稍有不同的是XLNet的输出与Bert输出不同。同样,官方文档给出了输出的内容:
在这里插入图片描述
具体的代码如下:

class MyXLNet(nn.Module):
    def __init__(self, num_classes=35, alpha=0.5):
        self.alpha = alpha
        super(MyXLNet, self).__init__()
        self.net = XLNetModel.from_pretrained(xlnet_cfg.xlnet_path).cuda()
        for name, param in self.net.named_parameters():
            if 'layer.11' in name or 'layer.10' in name or 'layer.9' in name or 'layer.8' in name or 'pooler.dense' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        self.MLP = nn.Sequential(
            nn.Linear(768, num_classes, bias=True),
        ).cuda()

    def forward(self, x):
        x = x.long()
        x = self.net(x, output_all_encoded_layers=False).last_hidden_state
        x = F.dropout(x, self.alpha, training=self.training)
        x = torch.max(x, dim=1)[0]
        x = self.MLP(x)
        return torch.sigmoid(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

通过param.requires_grad = False锁定前8层的参数,不进行反向更新,微调后四层。进一步地调参可能使得效果更好,具体就不做啦。

实验结果

基于Bert(最大程度设置为512)的召回结果达到38.74%,而基于XLNet(最大长度设置为1024)的召回率达到了47.30%。不过,后者的训练时间几乎是前者的10倍。这对于只有几百条的测试文本几乎是不可接受的。
在这里插入图片描述
具体的代码很简单,连接如下:
https://github.com/songruiecho/BertMulti-LabelTextClassification
其中也包含了Bert的简单实现。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/677868
推荐阅读
相关标签
  

闽ICP备14008679号