当前位置:   article > 正文

两种方法快速实现基于 BERT 预训练模型的分类任务:kashgari 和 keras_bert(闫广庆 nlp)

闫广庆 nlp
# -*- coding: utf-8 -*-
"""Fine-tune a Chinese BERT checkpoint (through keras_bert) for text classification.

Training data comes from ``read_message()``, which returns a list of character
sequences and a parallel list of label strings. Each distinct label string is
mapped to an integer class id and the model is trained with a softmax head, so
any number of classes (including two) is supported.
"""
import codecs
import os
import pickle

import numpy as np
import pandas as pd
from keras.layers import Dense, Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras_bert import Tokenizer, load_trained_model_from_checkpoint
from tensorflow import keras

# NOTE(review): this callback comes from tensorflow.keras while the model below is
# built with standalone keras; confirm the two are compatible in the installed
# versions (or construct the callback from the same keras package as the model).
tf_board_callback = keras.callbacks.TensorBoard(log_dir='./tf_dir', update_freq=1000)

# Maximum number of characters per sample passed to the tokenizer
# ([CLS]/[SEP] are added on top of this).
max_len = 100
config_path = '/input1/BERT/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/input1/BERT/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/input1/BERT/chinese_L-12_H-768_A-12/vocab.txt'

# Cache written on the first run and reused afterwards.
CACHE_DIR = 'sets'
X_CACHE = 'sets/x_items_without000.pkl'
Y_CACHE = 'sets/train_y_without000.pkl'

token_dict = {}


def read_message():
    """Build (or load from the pickle cache) the training samples and labels.

    Each row of the action table is joined with its user row (by column 0) and
    its job row (by column 1). The ``str()`` of each joined row is split into
    single characters to form the sample. The label is the concatenation of
    action columns 2-4. Rows whose label is ``"000"`` are dropped.

    Returns:
        (x_items, train_y): a list of character lists and a parallel list of
        label strings.
    """
    # Only trust the cache when both halves of it exist.
    if os.path.exists(X_CACHE) and os.path.exists(Y_CACHE):
        with open(X_CACHE, 'rb') as f:
            x_items = pickle.load(f)
        with open(Y_CACHE, 'rb') as f:
            train_y = pickle.load(f)
        return x_items, train_y

    user_message = pd.read_csv("/input0/table1_user", sep="\t")
    jd_message = pd.read_csv("/input0/table2_jd", sep="\t")
    match_message = pd.read_csv("/input0/table3_action", sep="\t")

    # Index rows by their first column for O(1) lookup (later duplicates win).
    user_message_index = {row[0]: row for row in user_message.values.tolist()}
    jd_message_index = {row[0]: row for row in jd_message.values.tolist()}

    x_items = []
    train_y = []
    for row in match_message.values.tolist():
        x_item = []
        if row[0] in user_message_index:
            x_item = list(str(user_message_index[row[0]]))
        if row[1] in jd_message_index:
            x_item.extend(str(jd_message_index[row[1]]))
        y_label = str(row[2]) + str(row[3]) + str(row[4])
        if y_label != "000":
            x_items.append(x_item)
            train_y.append(y_label)

    # The cache directory may not exist yet on a fresh checkout.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(X_CACHE, 'wb') as f:
        pickle.dump(x_items, f)
    with open(Y_CACHE, 'wb') as f:
        pickle.dump(train_y, f)
    return x_items, train_y


# Vocabulary: one token per line; the id is the line index.
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)


class OurTokenizer(Tokenizer):
    """Character-level tokenizer that keeps every input character as one token."""

    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]')  # whitespace is mapped to the untrained [unused1] token
            else:
                R.append('[UNK]')  # any other out-of-vocabulary character becomes [UNK]
        return R


tokenizer = OurTokenizer(token_dict)

# Replace read_message() with your own loader that returns (samples, labels).
x_items, labels = read_message()
# Map every distinct label string to a stable integer class id.
label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
num_classes = len(label2id)
data = [(text, label2id[label]) for text, label in zip(x_items, labels)]

# Split into training and validation sets at a 9:1 ratio.
random_order = list(range(len(data)))
np.random.shuffle(random_order)
train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0]
valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0]


def seq_padding(X, padding=0):
    """Right-pad every sequence in X with `padding` to the longest length and stack them."""
    L = [len(x) for x in X]
    ML = max(L)
    return np.array([
        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
    ])


class data_generator:
    """Endless iterator of shuffled ``([token_ids, segment_ids], labels)`` batches."""

    def __init__(self, data, batch_size=32):
        self.data = data
        self.batch_size = batch_size
        # Batches per epoch, counting a trailing partial batch.
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            np.random.shuffle(idxs)
            X1, X2, Y = [], [], []
            for i in idxs:
                text, y = self.data[i]
                x1, x2 = tokenizer.encode(first=text[:max_len])
                X1.append(x1)
                X2.append(x2)
                Y.append([y])
                # Emit a full batch, or whatever is left at the end of the epoch.
                if len(X1) == self.batch_size or i == idxs[-1]:
                    yield [seq_padding(X1), seq_padding(X2)], seq_padding(Y)
                    X1, X2, Y = [], [], []


bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
for layer in bert_model.layers:
    layer.trainable = True

x1_in = Input(shape=(None,))
x2_in = Input(shape=(None,))
x = bert_model([x1_in, x2_in])
x = Lambda(lambda t: t[:, 0])(x)  # take the [CLS] position as the sentence vector
p = Dense(num_classes, activation='softmax')(x)

model = Model([x1_in, x2_in], p)
model.compile(
    # Labels are integer class ids of shape (batch, 1).
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(1e-5),  # use a small enough learning rate for fine-tuning
    metrics=['accuracy'],
)
model.summary()

train_D = data_generator(train_data)
valid_D = data_generator(valid_data)

model.fit_generator(
    iter(train_D),
    steps_per_epoch=len(train_D),
    epochs=5,
    validation_data=iter(valid_D),
    validation_steps=len(valid_D),
    callbacks=[tf_board_callback],
)

以上代码基于 keras-bert 和 BERT 中文预训练语言模型实现分类任务。如果你有自己的分类任务,可以把第 83 行调用的数据读取方法替换成自己的方法。代码已经过测试,可以跑通。

"""Train a kashgari CNN text classifier on top of BERT embeddings."""
import os
import pickle

import kashgari
import pandas as pd
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.classification import CNNModel
from tensorflow import keras

tf_board_callback = keras.callbacks.TensorBoard(log_dir='./tf_dir', update_freq=1000)

# Directory holding the competition tables.
DATA_DIR = "../data/zhaopin_round1_train_20190716/zhaopin_round1_train_20190716"

# Cache written on the first run and reused afterwards.
CACHE_DIR = "sets"
X_CACHE = "sets/x_items_without000.pkl"
Y_CACHE = "sets/train_y_without000.pkl"


def read_message():
    """Read the source tables (or the pickle cache) and return samples and labels.

    Each row of the action table is joined with its user row (by column 0) and
    its job row (by column 1). The ``str()`` of each joined row is split into
    single characters to form the sample. The label is the concatenation of
    action columns 2-4. Rows whose label is ``"000"`` are dropped.

    Returns:
        (x_items, train_y): a list of character lists and a parallel list of
        label strings.
    """
    # Only trust the cache when both halves of it exist.
    if os.path.exists(X_CACHE) and os.path.exists(Y_CACHE):
        with open(X_CACHE, 'rb') as f:
            x_items = pickle.load(f)
        with open(Y_CACHE, 'rb') as f:
            train_y = pickle.load(f)
        return x_items, train_y

    user_message = pd.read_csv(DATA_DIR + "/table1_user", sep="\t")
    jd_message = pd.read_csv(DATA_DIR + "/table2_jd", sep="\t")
    match_message = pd.read_csv(DATA_DIR + "/table3_action", sep="\t")

    # Index rows by their first column for O(1) lookup (later duplicates win).
    user_message_index = {row[0]: row for row in user_message.values.tolist()}
    jd_message_index = {row[0]: row for row in jd_message.values.tolist()}

    x_items = []
    train_y = []
    for row in match_message.values.tolist():
        x_item = []
        if row[0] in user_message_index:
            x_item = list(str(user_message_index[row[0]]))
        if row[1] in jd_message_index:
            x_item.extend(str(jd_message_index[row[1]]))
        y_label = str(row[2]) + str(row[3]) + str(row[4])
        if y_label != "000":
            x_items.append(x_item)
            train_y.append(y_label)

    # The cache directory may not exist yet on a fresh checkout.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(X_CACHE, 'wb') as f:
        pickle.dump(x_items, f)
    with open(Y_CACHE, 'wb') as f:
        pickle.dump(train_y, f)
    return x_items, train_y


def train():
    """Train a CNN classifier on BERT embeddings and save it under /output."""
    x_items, labels = read_message()
    # presumably a local folder containing the Chinese BERT checkpoint — confirm path
    embed = BERTEmbedding("chinese_L-12_H-768_A-12",
                          task=kashgari.CLASSIFICATION,
                          sequence_length=64)
    model = CNNModel(embed)
    # Training samples, labels, epochs and batch size.
    model.fit(x_items,
              labels,
              epochs=8,
              batch_size=16,
              callbacks=[tf_board_callback])
    model.save("/output/CNN_classfition_4-model")


if __name__ == '__main__':
    train()

如果你觉得 keras-bert 比较麻烦,可以试着学习一下 kashgari。今天时间比较紧,所以我直接贴出了代码。所有数据都来源于阿里天池大赛的人岗匹配比赛;这里只介绍比赛,不传播数据集。

kashgari 的详细文档见下面的链接:

https://kashgari-zh.bmio.net/tutorial/text-classification/#_3

我是北京妙医佳健康科技集团妙云事业部的闫广庆,专注于医疗 NLP。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/人工智能uu/article/detail/959010
推荐阅读
相关标签
  

闽ICP备14008679号