
BERT (5) --- Hands-on: BERT + CNN Text Classification

1. Model Overview

Take the token-level output of the BERT model (get_sequence_output), i.e. the output of the last Transformer layer, and use it as embedding_inputs, the input to the convolution. Convolution and pooling are done with three different filter sizes, [2, 3, 4], with 128 filters of each size, and the three pooled results are then concatenated. The pipeline is similar to the figure below, with one difference: the figure uses the same three filter sizes [2, 3, 4] but only 2 filters of each size.
[Figure: TextCNN-style convolution over the BERT sequence output, filter sizes [2, 3, 4], two filters per size]
The concatenated output has shape [batch_size, num_filters * len(filter_sizes)] = [batch_size, 128 * 3]. It is then passed through two fully connected layers, the second of which performs the classification.
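
To make the description concrete, here is a minimal sketch of the graph in TensorFlow 1.x. It is not the original author's code: modeling is the module from the BERT source, and the names bert_cnn, hidden_dim and num_classes are illustrative.

import tensorflow as tf
import modeling   # from the BERT source code


def bert_cnn(bert_config, input_ids, input_mask, segment_ids,
             num_classes, is_training,
             filter_sizes=(2, 3, 4), num_filters=128, hidden_dim=128):
    bert = modeling.BertModel(config=bert_config,
                              is_training=is_training,
                              input_ids=input_ids,
                              input_mask=input_mask,
                              token_type_ids=segment_ids)
    # Token-level output of the last Transformer layer: [batch_size, seq_length, hidden_size]
    embedding_inputs = bert.get_sequence_output()

    pooled = []
    for size in filter_sizes:
        with tf.variable_scope("conv-%d" % size):
            conv = tf.layers.conv1d(embedding_inputs, num_filters, size, name="conv")
            # Global max pooling over the sequence dimension.
            pooled.append(tf.reduce_max(conv, axis=1))
    # [batch_size, num_filters * len(filter_sizes)] = [batch_size, 128 * 3]
    h = tf.concat(pooled, axis=-1)

    fc = tf.layers.dense(h, hidden_dim, activation=tf.nn.relu, name="fc1")
    logits = tf.layers.dense(fc, num_classes, name="fc2")   # second dense layer does the classification
    return logits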

2. Data Processing and Training

Because the data first goes through the BERT model and BERT's output is then fed into the CNN, the data has to be converted into the format BERT expects. The data processing here therefore largely follows the data-processing code in the BERT source.

First, as in the BERT fine-tuning article, define a custom class to process the raw data. Its main responsibilities are loading the training, test and validation data and defining the classification labels. The implementation is as follows:

import os
import codecs

import numpy as np
import tokenization  # from the BERT source code
from run_classifier import InputExample  # or a local copy of the InputExample class


class TextProcessor(object):
    """Load the data sets as lists of InputExample objects."""

    def get_train_examples(self, data_dir):
        """Load train examples."""
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Load dev examples."""
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """Load test examples."""
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """Set labels."""
        return ['sport', 'military', 'aerospace', 'car', 'business', 'chemistry', 'construction', 'culture', 'electric', 'finance', 'geology', 'it', 'law', 'mechanical', 'medicine', 'tourism']

    def _read_file(self, input_file):
        """Read a tab-separated file of (label, text) lines, skipping malformed lines."""
        with codecs.open(input_file, "r", encoding='utf-8') as f:
            lines = []
            for line in f.readlines():
                try:
                    line = line.strip().split('\t')
                    assert len(line) == 2
                    lines.append(line)
                except Exception:
                    pass
            np.random.shuffle(lines)
            return lines

    def _create_examples(self, lines, set_type):
        """Create examples for the data set."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[1])
            label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

The examples produced above contain only BERT's text_a, text_b and label fields, which cannot yet be fed into the BERT model. They first have to be converted into the form [CLS] + tokenized text_a + [SEP] + tokenized text_b. The convert_examples_to_features function below converts every InputExample into this token form and produces the four variables the BERT model needs: input_ids, input_mask, segment_ids and label_ids.

import collections

import tensorflow as tf


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Convert InputExamples into the token-level features BERT expects."""
    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    input_data = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        # Truncate to leave room for [CLS] and [SEP].
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]

        # Build [CLS] + tokens + [SEP]; all segment ids are 0 for a single sentence.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask is 1 for real tokens and 0 for padding tokens.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to max_seq_length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        # Log the first few examples for a sanity check.
        if ex_index < 3:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info("tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

        features = collections.OrderedDict()
        features["input_ids"] = input_ids
        features["input_mask"] = input_mask
        features["segment_ids"] = segment_ids
        features["label_ids"] = label_id
        input_data.append(features)

    return input_data
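
For reference, the pieces above can be wired together roughly as follows; the vocab path, data directory and max_seq_length are illustrative, not from the original post.

# Illustrative wiring; adjust paths and max_seq_length to your setup.
tokenizer = tokenization.FullTokenizer(
    vocab_file="chinese_L-12_H-768_A-12/vocab.txt", do_lower_case=True)
processor = TextProcessor()

train_examples = processor.get_train_examples("data/")
train_data = convert_examples_to_features(
    train_examples, processor.get_labels(), max_seq_length=128, tokenizer=tokenizer)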

Next, load the saved pre-trained BERT model and start training. Since the model is fed one batch of data at a time, the processed features also have to be packed into batches.
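
The post does not show the batching helper itself; a minimal batch_iter sketch, assuming each element of the data is the OrderedDict produced by convert_examples_to_features, could look like this:

import numpy as np


def batch_iter(data, batch_size):
    # Shuffle the feature dicts and yield them as numpy batches in the order
    # the training loop unpacks them: ids, mask, segment, label.
    np.random.shuffle(data)
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        yield (np.array([f["input_ids"] for f in batch]),
               np.array([f["input_mask"] for f in batch]),
               np.array([f["segment_ids"] for f in batch]),
               np.array([f["label_ids"] for f in batch]))

The training loop itself looks like this: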

    # best_acc, last_improved and flag track early stopping across epochs.
    best_acc = 0.0
    last_improved = 0
    flag = False
    for epoch in range(config.num_epochs):
        batch_train = batch_iter(train_data, config.batch_size)
        start = time.time()
        tf.logging.info('Epoch:%d' % (epoch + 1))
        for batch_ids, batch_mask, batch_segment, batch_label in batch_train:
            feed_dict = feed_data(batch_ids, batch_mask, batch_segment, batch_label, config.keep_prob)
            _, global_step, train_summaries, train_loss, train_accuracy = session.run(
                [model.optim, model.global_step, merged_summary, model.loss, model.acc],
                feed_dict=feed_dict)
            tf.logging.info('step:%d' % (global_step))
            if global_step % config.print_per_batch == 0:
                end = time.time()
                val_loss, val_accuracy = evaluate(session, dev_data)
                merged_acc = (train_accuracy + val_accuracy) / 2
                if merged_acc > best_acc:
                    # Save the checkpoint whenever the averaged accuracy improves.
                    saver.save(session, save_path)
                    best_acc = merged_acc
                    last_improved = global_step
                    improved_str = '*'
                else:
                    improved_str = ''
                tf.logging.info("step: {}, train loss: {:.3f}, train accuracy: {:.3f}, val loss: {:.3f}, val accuracy: {:.3f}, training speed: {:.3f}sec/batch {}".format(
                        global_step, train_loss, train_accuracy, val_loss, val_accuracy, (end - start) / config.print_per_batch, improved_str))
                start = time.time()

            if global_step - last_improved > config.require_improvement:
                # Early stopping: no improvement for too many steps.
                tf.logging.info("No optimization over 1500 steps, stop training")
                flag = True
                break
        if flag:
            break
        config.lr *= config.lr_decay
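
The loop references feed_data and evaluate, which are not shown in the post. A minimal sketch, assuming the model exposes input_ids, input_mask, segment_ids, labels and keep_prob placeholders (those names are illustrative):

def feed_data(batch_ids, batch_mask, batch_segment, batch_label, keep_prob):
    # Map one batch onto the model's placeholders (placeholder names assumed).
    return {model.input_ids: batch_ids,
            model.input_mask: batch_mask,
            model.segment_ids: batch_segment,
            model.labels: batch_label,
            model.keep_prob: keep_prob}


def evaluate(session, dev_data):
    # Average loss and accuracy over the dev set, weighted by batch size.
    total_loss, total_acc, total_len = 0.0, 0.0, 0
    for batch_ids, batch_mask, batch_segment, batch_label in batch_iter(dev_data, config.batch_size):
        feed_dict = feed_data(batch_ids, batch_mask, batch_segment, batch_label, 1.0)
        loss, acc = session.run([model.loss, model.acc], feed_dict=feed_dict)
        total_loss += loss * len(batch_ids)
        total_acc += acc * len(batch_ids)
        total_len += len(batch_ids)
    return total_loss / total_len, total_acc / total_len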

3. Building the Service

The service is built with Tornado, as shown below:

import json
import time
import traceback

import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options

# Logger, LOGPATH and predict are project-specific modules imported elsewhere;
# the port default below is illustrative.
define("port", default=8000, help="run on the given port", type=int)


class ClassifierHandler(tornado.web.RequestHandler):
    def post(self):
        # One error log file per day.
        error_logger = Logger(loggername="bert_cnn_error_" + time.strftime("%Y_%m_%d", time.localtime()),
                              logpath=LOGPATH + "bert_cnn_error_" + time.strftime("%Y_%m_%d", time.localtime()) + ".log").log()
        try:
            segment = self.get_argument("segment")
            language = self.get_argument("lang")
            sentence_data = json.loads(segment)
            results = predict.run(sentence_data, language)
            self.write(json.dumps(results))
        except Exception as e:
            self.write(repr(e))
            error_logger.error("error message:%s" % repr(e))
            error_logger.error("error position:%s" % traceback.format_exc())

    def write_error(self, status_code, **kwargs):
        self.write("errors: %d." % status_code)


if __name__ == "__main__":
    tornado.options.parse_command_line()
    app = tornado.web.Application(handlers=[(r"/classifier", ClassifierHandler)])
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()
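
To try the endpoint, post the two form fields the handler reads: segment (assumed here to be a JSON-encoded list of texts) and lang. The port, sentence and language code below are only an illustration.

import json
import requests

# Illustrative request; the port, payload shape and language code are assumptions.
payload = {
    "segment": json.dumps(["the new electric car was unveiled at the auto show"]),
    "lang": "en",
}
resp = requests.post("http://localhost:8000/classifier", data=payload)
print(resp.text)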

4. Other Notes

The data is split into 16 classes ['sport', 'military', 'aerospace', 'car', 'business', 'chemistry', 'construction', 'culture', 'electric', 'finance', 'geology', 'it', 'law', 'mechanical', 'medicine', 'tourism'], with roughly 100 million samples in total.

Chinese BERT model: 12-layer, 768-hidden, 12-heads, 110M parameters
English BERT model: 12-layer, 768-hidden, 12-heads, 110M parameters

Main program (entry point): text_run.py
