赞
踩
六年的大学生涯结束了,目前在搜索推荐岗位上继续进阶,近期正好在做类目预测多标签分类的项目,因此把相关的模型记录总结一下,便于后续查阅总结
一、理论篇:
在我们的场景中,文本数据量比较大,因此直接采用深度学习模型来预测文本类目的多标签,而TextCNN向来以速度快,准确率高著称。 TextCNN的核心思想是抓取文本的局部特征:通过不同的卷积核尺寸(确切的说是卷积核高度)来提取文本的N-gram信息,然后通过最大池化操作来突出各个卷积操作提取的最关键信息(颇有一番Attention的味道),拼接后通过全连接层对特征进行组合,最后通过交叉熵损失函数来训练模型。
接下来的几层就跟具体的任务相关了,一般都会拼接特征,在通过全连接层自由组合提取出来的特征实现分类。在损失函数上,二分类和多标签分类可以采用基于Sigmoid函数的交叉熵损失函数binary_crossentropy,多分类任务可以采用基于Softmax的多类别交叉熵损失函数(categorical_crossentropy)。
二、代码部分:
def textcnn(hyper_parameters): input = Input(shape=(hyper_parameters.max_len,)) if hyper_parameters.embedding_matrix is None: embedding = Embedding(input_dim=hyper_parameters.vocab + 1, output_dim=hyper_parameters.emd_dim, input_length=hyper_parameters.MAX_LEN, trainable=True)(input) else: # 使用预训练矩阵初始化Embedding embedding = Embedding(input_dim=hyper_parameters.vocab + 1, output_dim=hyper_parameters.emd_dim, weights=[hyper_parameters.embedding_matrix], input_length=hyper_parameters.MAX_LEN, trainable=False)(input) convs = [] for kernel_size in hyper_parameters.kernel_size: conv = Conv1D(hyper_parameters.conv_code, kernel_size, activation=hyper_parameters.relu)(embedding) pool = MaxPooling1D()(conv) convs.append(pool) concat = Concatenate()(convs) flattern = Flatten()(concat) dropout = Dropout(hyper_parameters.dropout)(flattern) output = Dense(hyper_parameters.classes, activation=hyper_parameters.sigmoid)(dropout) model = Model(input, output) model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) return model
在Embedding部分,如果有条件可以使用自己预训练的文本信息来初始化Embedding矩阵,效果可能会比随机初始化Embedding提升一点。
三、几点思考:
1.TextCNN能用于文本分类的主要原因是什么?
除了预训练文本外,TextCNN通过利用不同的卷积核尺寸并行提取文本的信息(类似N-gram),并通过最大池化来突出最重要的关键词来实现分类。
2.TextCNN的缺点:
2.1. TextCNN的卷积和池化操作会丢失文本序列中的词汇顺序和位置信息等内容,但也可利用这一点来增强文本,例如白色旅游鞋,可以添加旅游鞋白色数据等,分词后白色和旅游鞋位置就可以互换来丰富语料 。
2.2. 在长文本使用TextCNN效果可能没有在短文本中效果好(具体还需要实践确认),原因同上,因此可以尝试使用Top k池化替代最大池化提取更多的文本信息。
参考文献:
1. https://arxiv.org/pdf/1408.5882.pdf2.https://zhuanlan.zhihu.com/p/77634533Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。