赞
踩
之前我写过一篇文章,利用bert来生成token级向量(对于中文语料来说就是字级别向量),参考我的文章:《使用BERT模型生成token级向量》。但是这样做有一个致命的缺点就是字符序列长度最长为512(包含[cls]和[sep])。其实对于大多数语料来说已经够了,但是对于有些语料库中样本的字符序列长度都比较长的情况,这就有些不够用了,比如我做一个法院文书领域预测任务,里面的事实部分许多都大于1000字,我做TextCharCNN的时候定义的最大长度为1500(能够涵盖百分之95以上的样本)。
这个时候怎么办呢,我想到了一个办法,就是用句子序列来代表他们。比如一段事实有1500个字,如果按照句号划分,则有80个句子。那么每一个句子,我们可以利用bert得到句子向量,我们可以把一个句子中包含的字符的最大长度人为定义为128(实际上这样一个句子得到的结果的shape是(128, 768),可以参考我文章开头提到的那篇文章。我的做法是首先要用bert模型在我们的数据集任务上进行微调,然后用微调过的模型去生成这样一个结果,然后取出第0个分量,也就是说,一句话有128个字符,第0个字符是[cls]字符,我们就取第0个字符代表这句话的向量表示,这也就是为什么我在前面提到一定要在我们的任务中微调过模型再拿来用,要不然这个[cls]向量取出来并不好用!!!)BERT微调的参照我的文章:《使用BERT预训练模型+微调进行文本分类》
那么每一个句子得到了一个shape为(768,)的向量,这就是这个句子的embedding,然后一个样本设定有80个句子,如果超过80个句子,则取前80个,如果不到80个句子,则填充768维全0向量。最终生成的结果是:(N,80,768)。N代表样本数,80代表句子最大长度,768代表向量维度,然后可以用这个结果去做mean_pooling,或者做卷积之类的。
下面代码(注释比较清晰,就不解释了):
- # 配置文件
- # data_root是模型文件,可以用预训练的,也可以用在分类任务上微调过的模型
- data_root = '../chinese_wwm_ext_L-12_H-768_A-12/'
- bert_config_file = data_root + 'bert_config.json'
- bert_config = modeling.BertConfig.from_json_file(bert_config_file)
- # init_checkpoint = data_root + 'bert_model.ckpt'
- # 这样的话,就是使用在具体任务上微调过的模型来做词向量
- init_checkpoint = '../model/cnews_fine_tune/model.ckpt-18674'
- # init_checkpoint = '../model/legal_fine_tune/model.ckpt-4153'
- bert_vocab_file = data_root + 'vocab.txt'
-
- # 经过处理的输入文件路径
- file_input_x_c_train = '../data/cnews/train_x.txt'
- file_input_x_c_val = '../data/cnews/val_x.txt'
- file_input_x_c_test = '../data/cnews/test_x.txt'
-
- # embedding存放路径
- # emb_file_dir = '../data/legal_domain/emb_fine_tune.h5'
-
- # graph
- input_ids = tf.placeholder(tf.int32, shape=[None, None], name='input_ids')
- input_mask = tf.placeholder(tf.int32, shape=[None, None], name='input_masks')
- segment_ids = tf.placeholder(tf.int32, shape=[None, None], name='segment_ids')
-
- # 每个sample固定为80个句子
- SEQ_LEN = 80
- # 每个句子固定为128个token
- SENTENCE_LEN = 126
-
-
- def get_batch_data(x):
- """生成批次数据,一个batch一个batch地产生句子向量"""
- data_len = len(x)
-
- word_mask = [[1] * (SENTENCE_LEN + 2) for i in range(data_len)]
- word_segment_ids = [[0] * (SENTENCE_LEN + 2) for i in range(data_len)]
- return x, word_mask, word_segment_ids
-
-
- def read_input(file_dir):
- # 从文件中读到所有需要转化的句子
- # 这里需要做统一长度为510
- # input_list = []
- with open(file_dir, 'r', encoding='utf-8') as f:
- input_list = f.readlines()
-
- # input_list是输入list,每一个元素是一个str,代表输入文本
- # 现在需要转化成id_list
- word_id_list = []
- for query in input_list:
- tmp_word_id_list = []
- quert_str = ''.join(query.strip().split())
- sentences = re.split('。', quert_str)
- # 在这里截取掉大于seq_len个句子的样本,保留其前seq_len个句子
- if len(sentences) > SEQ_LEN:
- sentences = sentences[:SEQ_LEN]
- for sentence in sentences:
- split_tokens = token.tokenize(sentence)
- if len(split_tokens) > SENTENCE_LEN:
- split_tokens = split_tokens[:SENTENCE_LEN]
- else:
- whil
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。