当前位置:   article > 正文

bert-textcnn实现多标签文本分类(基于keras+keras-bert构建)_bert-textcnn实现多标签文本分类(基于keras+keras-bert构建)_textcn

bert-textcnn实现多标签文本分类(基于keras+keras-bert构建)_textcnn多标签分类_

基于keras+keras-bert构建bert-textcnn模型实现多标签文本分类

跑别人的代码,最痛苦的莫不在于环境有错误、代码含义不懂。自己从头到尾尝试了一遍,过程很艰难,为了方便同样在学习的朋友,在这里,我会在项目文件中提供详细的requirements,保证你能一次性跑成功。此外,每个部分我都会尽可能的添加详细的注释,使得读者能够知道每一步的意义和结果。

前言

  1. 什么是bert?

    BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

    ​ BERT是2018年Google AI Language提出的一种预训练语言模型。BERT通过联合调节所有层的左右上下文来预训练来自未标记文本的深度双向表示。因此,预训练的 BERT 模型可以通过一个额外的输出层进行微调,从而为各种任务(例如问答和语言推理)创建最先进的模型,而无需对特定于任务的架构进行大量修改。BERT 在概念上很简单,在经验上很强大。它在 11 个自然语言处理任务上获得了新的 state-of-the-art 结果,是NLP发展的里程碑。

  2. 什么是textcnn?

    Convolutional Neural Network for Sentence Classification

    ​ TextCNN是Yoon Kim在2014年将CNN网络应用于句子级的文本分类所提出的结构。如下图所示,TextCNN利用多个不同的kernel size来提取句子中的关键信息,不同的kernel size的结果进行拼接进行pooling操作,以更好的获取文本的局部特征。

    image-20220817150250202
  3. 本项目中,我首先利用BERT输出句子的嵌入表示,然后将嵌入表示结果输入构造好的多尺寸TextCNN中进行特征提取,并用作最后的分类。

数据介绍

项目数据用的是2020语言与智能技术竞赛:事件抽取任务,数据我会直接放在项目的[data]文件夹中。

  1. 数据的基本结构:文本对应标签与文本之间用空格隔开,多个标签之间用|隔开。

    组织关系-裁员 雀巢裁员4000人:时代抛弃你时,连招呼都不会打!
    组织关系-裁员 美国“未来为”子公司大幅度裁员,这是为什么呢?任正非正式回应
    组织关系-裁员 这一全球巨头“凉凉”“捅刀”华为后裁员5000现市值缩水800亿
    组织关系-裁员 被证实将再裁员1800人AT&T在为落后的经营模式买单
    组织关系-裁员 又一网约车巨头倒下:三个月裁员835名员工,滴滴又该何去何从
    组织关系-裁员 8月20日消息,据腾讯新闻《一线》报道,知情人士表示,为了控制成本支出,蔚来计划将美国分公司的人员规模除自动驾驶业务相关人员外,减少至200人左右。截至美国时间8月16日,蔚来位于美国硅谷的分公司已裁减100名员工。
    司法行为-起诉|组织关系-裁员 最近,一位前便利蜂员工就因公司违规裁员,将便利蜂所在的公司虫极科技(北京)有限公司告上法庭。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
  2. 数据集的基本信息

模型搭建

  1. 构建包含多个kernel size的TextCNN网络。

    def textcnn(inputs):
        # 选用3、4、5三个卷积核进行特征提取,最后拼接后输出用于分类。
        kernel_size = [3, 4, 5]
        cnn_features = []
        for size in kernel_size:
            cnn = keras.layers.Conv1D(filters=256, kernel_size=size)(inputs)  # shape=[batch_size,maxlen-2,256]
            cnn = keras.layers.GlobalMaxPooling1D()(cnn)  # shape=[batch_size,256]
            cnn_features.append(cnn)
        # 对kernel_size=3、4、5时提取的特征进行拼接
        output = keras.layers.concatenate(cnn_features, axis=-1)  # [batch_size,256*3]
        # 返回textcnn提取的特征结果
        return output
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
  2. 构建bert_textcnn模型。

  • 首先利用keras-bert加载预训练好的bert,这里用的bert是哈工大训练的chinese_bert_wwm_L-12_H-768_A-12。

  • 取出bert的输出中的[cls]向量,[cls]可以直接用于分类,也可以与其它网络的输出拼接。

  • 取出bert输出中关于输入句子的表示(word_embedding),bert在输入时在句子的头和尾分类添加了一个[CLS]、[SEP],可以选择去除这两个标志。

  • 将word_embedding输入构造好的多kernel size的TextCNN网络,获得经由TextCNN获得特征(cnn_features)。

  • 将[cls]与cnn_features进行拼接后用于分类。

  • 根据输入和输出封装模型,并进行必要参数的配置。

  • 模型最后的结果如下所示(bert仅展示最后一层):image-20220817152647722

  • 详细代码如下:

    def build_bert_textcnn_model(config_path, checkpoint_path, class_nums):
        """
        :param config_path: bert_config.json所在位置。
        :param checkpoint_path: bert_model.ckpt所在位置。
        :param class_nums: 最终模型的输出的维度(分类的类别)。
        :return:返回搭建好的模型。
        """
        # 加载预训练好的bert
        bert = load_trained_model_from_checkpoint(
            config_file=config_path,
            checkpoint_file=checkpoint_path,
            seq_len=None
        )
    
        # 取出[cls],可以直接用于分类,也可以与其它网络的输出拼接。
        cls_features = keras.layers.Lambda(
            lambda x: x[:, 0],
            name='cls'
        )(bert.output)  # shape=[batch_size,768]
    
        # 去除第一个[cls]和最后一个[sep],得到输入句子的embedding,用作textcnn的输入。
        word_embedding = keras.layers.Lambda(
            lambda x: x[:, 1:-1],
            name='word_embedding'
        )(bert.output)  # shape=[batch_size,maxlen-2,768]
    
        # 将句子的embedding,输入textcnn,得到经由textcnn提取的特征。
        cnn_features = textcnn(word_embedding)  # shape=[batch_size,cnn_output_dim]
    
        # 将cls特征与textcnn特征进行拼接。
        all_features = keras.layers.concatenate([cls_features, cnn_features], axis=-1)  # shape=[batch_size,cnn_output_dim+768]
    
        # 应用dropout缓解过拟合的现象,rate一般在0.2-0.5。
        all_features = keras.layers.Dropout(0.2)(all_features)  # shape=[batch_size,cnn_output_dim+768]
    
        # 降维
        dense = keras.layers.Dense(units=256, activation='relu')(all_features)  # shape=[batch_size,256]
    
        # 输出结果
        output = keras.layers.Dense(
            units=class_nums,
            activation='sigmoid'
        )(dense)  # shape=[batch_size,class_nums]
    
        # 根据输入和输出构建构建模型
        model = keras.models.Model(bert.input, output, name='bert-textcnn')
    
        model.compile(
            loss='binary_crossentropy',
            optimizer=keras.optimizers.Adam(config.learning_rate),
            metrics=['accuracy']
        )
        return model
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53

模型训练

模型的训练大致有以下4步:

  • 加载训练集、测试集的数据。
  • 对训练集的文本、标签;测试集的文本、标签分别进行编码。
  • 初始化模型,将训练集、测试集的编码结果送入模型开始训练。
  • 绘制训练过程中的训练与验证的loss与acc图像(可选)。
  1. 加载训练集

    # 用以加载数据
    def load_data(txt_file_path):
        text_list = []
        label_list = []
        with open(txt_file_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                line = line.strip().split()
                label_list.append(line[0].split('|'))
                text_list.append(line[1])
        return text_list, label_list
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
  2. 对文本编码

    ​ 对文本编码需要弄清楚,输入给bert的是什么,bert的输入需要token_id与segment_id,是tokenizer操作后的返回值。

    # 加载bert字典,构造分词器。
    token_dict = load_vocabulary(config.bert_dict_path)
    tokenizer = Tokenizer(token_dict)
    # 对文本编码
    def encoding_text(content_list):
        token_ids = []
        segment_ids = []
        for line in tqdm(content_list):
        	# max_len是用于保证所有的输入长度一致,长度不足时会补0,长度超过时会截断。
            token_id, segment_id = tokenizer.encode(first=line, max_len=config.max_len) 
            token_ids.append(token_id)
            segment_ids.append(segment_id)
        # 输入给模型的数据不能是list,这里需要做一下转换编程array。
        encoding_res = [np.array(token_ids), np.array(segment_ids)]
        return encoding_res
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
  3. 主函数

    if __name__ == "__main__":
        # 读取训练集与测试集
        train_content_x, train_label_y = load_data(config.train_dataset_path)
        test_content_x, test_label_y = load_data(config.test_dataset_path)
    
        # 打乱训练集的数据
        index = [i for i in range(len(train_content_x))]
        random.shuffle(index)  # 打乱索引表
        # 按打乱后的索引,重新组织训练集
        train_content_x = [train_content_x[i] for i in index]
        train_label_y = [train_label_y[i] for i in index]
    
        # 对训练集与测试集的文本编码
        train_x = encoding_text(train_content_x)
        test_x = encoding_text(test_content_x)
    
        # 对标签集编码(调用sklearn的多标签编码器)
        mlb = MultiLabelBinarizer()
        mlb.fit(train_label_y)
        # 保存此时的mlb,后面在预测时评估时需要加载标签集。
        pickle.dump(mlb, open('./data/mlb.pkl', 'wb'))
        # 分别对训练集和测试集的标签进行编码,并转换为array。
        train_y = np.array(mlb.transform(train_label_y))
        test_y = np.array(mlb.transform(test_label_y))
    	# 初始化模型,并输出模型的结果
        model = build_bert_textcnn_model(config.bert_config_path, config.bert_checkpoint_path, len(mlb.classes_))
        model.summary()
        # 开始模型的训练,并保存训练的历史数据(loss、accuracy)用以最后绘图
        history = model.fit(train_x, train_y, validation_data=(test_x, test_y), batch_size=config.batch_size, epochs=config.epochs)
        # 保存模型为h5
        model.save("./model/bert_textcnn.h5")
    
        # 训练过程可视化
        # 绘制训练loss和验证loss的对比图
        plt.subplot(2, 1, 1)
        epochs = len(history.history['loss'])
        plt.plot(range(epochs), history.history['loss'], label='loss')
        plt.plot(range(epochs), history.history['val_loss'], label='val_loss')
        plt.legend()
        # 绘制训练acc和验证acc的对比图
        plt.subplot(2, 1, 2)
        epochs = len(history.history['accuracy'])
        plt.plot(range(epochs), history.history['accuracy'], label='acc')
        plt.plot(range(epochs), history.history['val_accuracy'], label='val_acc')
        plt.legend()
        # 保存loss与acc对比图
        plt.savefig("./model/bert-textcnn-loss-acc.png")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47

模型评估

模型评估大致有以下几步:

  • 加载评估集(测试集)。

  • 对评估集中数据逐条预测保存预测结果。

  • 计算accuracy,调用classification_report输出各个标签的详细评估结果,调用hamming_loss输出汉明损失。

  • 详细代码及注释如下:

    # 加载bert字典,构造分词器。
    token_dict = load_vocabulary(config.bert_dict_path)
    tokenizer = Tokenizer(token_dict)
    
    # 加载训练好的模型
    model = load_model('./model/bert_textcnn.h5', custom_objects=get_custom_objects())
    mlb = pickle.load(open('./data/mlb.pkl', 'rb'))
    
    
    def load_data(txt_file_path):
        text_list = []
        label_list = []
        with open(txt_file_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                line = line.strip().split()
                label_list.append(line[0].split('|'))
                text_list.append(line[1])
        return text_list, label_list
    
    
    def predict_single_text(text):
        # 编码后得出句子给bert的输入
        token_id, segment_id = tokenizer.encode(first=text, max_len=config.max_len)
        # 得到预测结果
        prediction = model.predict([[token_id], [segment_id]])[0]
    	# 这里以阈值0.5进行标签的筛选,取出值大于0.5标签的索引
        indices = [i for i in range(len(prediction)) if prediction[i] > 0.5]
        # 将索引转换为最终的标签集
        lables = [mlb.classes_.tolist()[i] for i in indices]
        # 输出最后结果的编码,用以评估
        one_hot = np.where(prediction > 0.5, 1, 0)
        return one_hot, lables
    
    
    def evaluate():
        test_x, test_y = load_data(config.test_dataset_path)
        true_y_list = mlb.transform(test_y)
    
        pred_y_list = []
        pred_labels = []
        for text in tqdm(test_x):
            pred_y, label = predict_single_text(text)
            pred_y_list.append(pred_y)
            pred_labels.append(label)
    
        # 计算accuracy,一条数据的所有标签全部预测正确则1,否则为0。
        test_len = len(test_y)
        correct_count = 0
        for i in range(test_len):
            if test_y[i] == pred_labels[i]:
                correct_count += 1
        accuracy = correct_count / test_len
    
        print(classification_report(true_y_list, pred_y_list, target_names=mlb.classes_.tolist(), digits=4))
        print("accuracy:{}".format(accuracy))
        print("hamming_loss:{}".format(hamming_loss(true_y_list, pred_y_list)))
    
    
    if __name__ == "__main__":
        evaluate()
        
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
  • 评估结果下所示:

    labelprecisionrecallf1-scoresupport
    交往-会见1.00001.00001.000012
    交往-感谢1.00000.87500.93338
    交往-探班1.00000.90000.947410
    交往-点赞0.88890.72730.800011
    交往-道歉0.81820.94740.878019
    产品行为-上映0.96970.91430.941235
    产品行为-下架1.00001.00001.000024
    产品行为-发布0.94810.97330.9605150
    产品行为-召回1.00001.00001.000036
    产品行为-获奖1.00000.93750.967716
    人生-产子/女0.86670.86670.866715
    人生-出轨1.00000.50000.66674
    人生-分手1.00000.93330.965515
    人生-失联1.00000.92860.963014
    人生-婚礼1.00000.66670.80006
    人生-庆生1.00000.87500.933316
    人生-怀孕1.00000.50000.66678
    人生-死亡0.95100.91510.9327106
    人生-求婚1.00001.00001.00009
    人生-离婚0.93940.93940.939433
    人生-结婚0.96550.65120.777843
    人生-订婚1.00000.77780.87509
    司法行为-举报1.00001.00001.000012
    司法行为-入狱0.90001.00000.947418
    司法行为-开庭0.92310.85710.888914
    司法行为-拘捕0.97700.96590.971488
    司法行为-立案1.00001.00001.00009
    司法行为-约谈0.96971.00000.984632
    司法行为-罚款1.00000.89660.945529
    司法行为-起诉0.87501.00000.933321
    灾害/意外-地震1.00001.00001.000014
    灾害/意外-坍/垮塌1.00000.80000.888910
    灾害/意外-坠机1.00001.00001.000013
    灾害/意外-洪灾1.00000.71430.83337
    灾害/意外-爆炸1.00001.00001.00009
    灾害/意外-袭击0.80000.75000.774216
    灾害/意外-起火0.96431.00000.981827
    灾害/意外-车祸0.93940.88570.911835
    竞赛行为-夺冠0.82140.82140.821456
    竞赛行为-晋级0.84210.96970.901433
    竞赛行为-禁赛0.88240.93750.909116
    竞赛行为-胜负0.97220.98590.9790213
    竞赛行为-退役0.91671.00000.956511
    竞赛行为-退赛0.83330.83330.833318
    组织关系-停职0.84621.00000.916711
    组织关系-加盟0.92310.87800.900041
    组织关系-裁员0.94740.94740.947419
    组织关系-解散0.90000.90000.900010
    组织关系-解约0.80000.80000.80005
    组织关系-解雇1.00000.30770.470613
    组织关系-辞/离职0.92211.00000.959571
    组织关系-退出0.83330.90910.869622
    组织行为-开幕0.93940.96880.953832
    组织行为-游行1.00000.88890.94129
    组织行为-罢工1.00000.87500.93338
    组织行为-闭幕1.00000.77780.87509
    财经/交易-上市1.00000.85710.92317
    财经/交易-出售/收购1.00000.91670.956524
    财经/交易-加息1.00000.33330.50003
    财经/交易-涨价0.80000.80000.80005
    财经/交易-涨停1.00001.00001.000027
    财经/交易-融资1.00001.00001.000014
    财经/交易-跌停0.93331.00000.965514
    财经/交易-降价1.00000.66670.80009
       micro avg     0.9450    0.9234    0.9341      1657
       macro avg     0.9509    0.8780    0.9029      1657
    weighted avg     0.9476    0.9234    0.9309      1657
     samples avg     0.9302    0.9347    0.9265      1657
    
    accuracy:0.8344459279038718
    hamming_loss:0.002218342405258293
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

模型预测

  • 模型预测其实就是将evaluate中的部分操作单独出来,具体的代码如下所示

    # 加载bert字典,构造分词器。
    token_dict = load_vocabulary(config.bert_dict_path)
    tokenizer = Tokenizer(token_dict)
    
    # 加载训练好的模型
    model = load_model('./model/bert_textcnn.h5', custom_objects=get_custom_objects())
    mlb = pickle.load(open('./data/mlb.pkl', 'rb'))
    
    
    # 预测单个句子的标签
    def predict_single_text(text):
        token_id, segment_id = tokenizer.encode(first=text, max_len=config.max_len)
        prediction = model.predict([[token_id], [segment_id]])[0]
    
        indices = [i for i in range(len(prediction)) if prediction[i] > 0.5]
        lables = [mlb.classes_.tolist()[i] for i in indices]
        return "|".join(lables)
    
    
    if __name__ == "__main__":
        text = "美的置业:贵阳项目挡墙垮塌致8人遇难已责令全面停工"
        result = predict_single_text(text)
        print(result)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

项目结构、下载、使用方法

  1. 项目结构

    BERT-TEXTCNN-MULTI-LABEL-TEXT-CLASSFICATION
    │  bert_textcnn_model.py # 构建模型的文件
    │  config.py # 项目的相关配置及参数文件
    │  model_evaluate.py # 用于模型评估的文件
    │  model_predict.py # 用于模型预测的文件
    │  model_train.py # 用于模型训练的文件
    │  requirements.txt # 项目所需的环境依赖(python3.6下直接运行安装本文件里的所有依赖可以稳定运行)
    │     
    ├─chinese_bert_wwm_L-12_H-768_A-12 # 预训练的bert模型,使用时需要自行去下载后复制到项目中。
    │      bert_config.json
    │      bert_model.ckpt.data-00000-of-00001
    │      bert_model.ckpt.index
    │      bert_model.ckpt.meta
    │      vocab.txt
    │      
    ├─data # 数据集
    │      mlb.pkl # 训练时生成(项目中已移除)
    │      multi-classification-test.txt
    │      multi-classification-train.txt
    │      
    ├─model # 此文件夹需自行新建
    │      bert-textcnn-loss-acc.png # 训练时的loss-acc图像(运行model_train.py可得)
    │      bert_textcnn.h5 # 训练得到的模型(运行model_train.py可得)
    │      model.png # 模型结构图(运行bert_textcnn_model.py可得)
    └─
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
  2. 项目下载地址

    bert-textcnn-for-multi-label-text-classfication

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/342463
推荐阅读
相关标签
  

闽ICP备14008679号