当前位置:   article > 正文

TextCNN网络详解(pytorch实现文本分类)_textcnn python 文本分类

textcnn python 文本分类

TextCNN的由来

CNN原本是被用在图像上面,后来在2014年一篇论文中被提出来用在自然语言处理任务上,在文本分类任务表现还不错,利用多个不同size的kernel来提取句子中的关键信息(类似于多窗口大小的n-gram),从而能够更好地捕捉局部相关性。下面来介绍一下模型结构。

TextCNN的模型结构:

主要有输入层,卷积层,池化层,全连接层

在这里插入图片描述

输入层

既然我们要使用卷积,我们就必须构造出一个n*m的矩阵,这个矩阵与卷积核做卷积运算。我们的原始数据是文本,怎么把它转化为矩阵的形式呢,我们可以采用one-hot编码或者是word-embedding编码,在次之前我们首先需要对文本进行分词。最常见的分词工具便是jieba分词。一段文本分词之后便得到n个词,记one-hot编码和word-embedding的特征维度为m,则我们构造的矩阵就为[n,m]。

每个词向量可以是预先在其他语料库中训练好的,也可以作为未知的参数由网络训练得到。这两种方法各有优势,预先训练的词嵌入可以利用其他语料库得到更多的先验知识,而由当前网络训练的词向量能够更好地抓住与当前任务相关联的特征。因此,图中的输入层实际采用了双通道的形式,即有两个 n*k 的输入矩阵,其中一个用预训练好的词嵌入表达,并且在训练过程中不再发生变化;另外一个也由同样的方式初始化,但是会作为参数,随着网络的训练过程发生改变。

卷积层和池化层

需要注意到,TextCNN中的卷积核与CV里面的卷积核不一样,cv里面的卷积核大多都是正方形的,比如最常见的3*3的卷积核,然后卷积核在整张image上沿高和宽按步长移动进行卷积操作。与CV中不同的是,在NLP中输入层的"image"是一个由词向量(word-embedding)拼成的词矩阵,且卷积核的宽和该词矩阵的宽相同,该宽度即为词向量大小,且卷积核只会在高度方向移动。因此,每次卷积核滑动过的位置都是完整的单词,不会将几个单词的一部分"vector"进行卷积,词矩阵的行表示离散的符号(也就是单词),这就保证了word作为语言中最小粒度的合理性(当然,如果研究的粒度是character-level而不是word-level,需要另外的方式处理)。

由于卷积核和word-embedding的宽度一致,一个卷积核与一个sentence做卷积运算,卷积后得到的结果是一个向量的形式,其shape=(sentence_len - filter_window_size + 1, 1),那么,在经过max-pooling操作后得到的就是一个标量。我们会使用多个filter_window_size(原因是,这样不同的kernel可以获取不同范围内词的关系,获得的是纵向的差异信息,即类似于n-gram,也就是在一个句子中不同范围的词出现会带来什么信息。比如可以使用2,3,4个词数分别作为卷积核的大小),每个filter_window_size又有num_filters个卷积核(原因是卷积神经网络学习的是卷积核中的参数,每个filter都有自己的关注点,这样多个卷积核就能学习到多个不同的信息。使用多个相同size的filter是为了从同一个窗口学习相互之间互补的特征。 比如可以设置size为3的filter有64个卷积核)。一个卷积核经过卷积操作只能得到一个scalar,将相同filter_window_size卷积出来的num_filter个scalar组合在一起,组成这个filter_window_size下的feature_vector。最后再将所有filter_window_size下的feature_vector也组合成一个single vector,作为最后一层softmax的输入。具体的情况可以看下面的图。

如下图所示:卷积核的数量为6个,其中filter_window_size有2,3,4,而每个size都有num_filters(2)个卷积核,每个卷积核与sentence做卷积运算得到一个vector(向量),然后max-pooling之后又得到一个scale(标量),之后便经过全连接层。

在这里插入图片描述

网络之后的结构就和具体的任务相关了,如果是文本分类的话,后面就接一个全连接层,并使用Softmax激活函数输出每个类别的概率。

TextCNN代码实现

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

class Config(object):
    def __init__(self):
        self.model_name = 'TextCNN'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备
        self.dropout = 0.5                                              # 随机失活
        self.num_classes = 10                                           # 类别数
        self.n_vocab = 10000                                            # 词表大小,在运行时赋值
        self.num_epochs = 20                                            # epoch数
        self.batch_size = 128                                           # mini-batch大小
        self.pad_size = 32                                              # 每句话处理成的长度(短填长切)
        self.learning_rate = 1e-3                                       # 学习率
        self.embed = 300                                                # 字向量维度
        self.filter_sizes = (2, 3, 4)                                   # 卷积核尺寸
        self.num_filters = 256                                          # 卷积核数量(channels数)


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        out = self.embedding(x[0])
        out = out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out

config=Config()
print(config)
model = Model(config)  
print(model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/905811
推荐阅读
相关标签
  

闽ICP备14008679号