当前位置:   article > 正文

bert中的数据输入制作data_generate_数字能作为bert的输入吗

数字能作为bert的输入吗

‘’’
这的数据是这样的:是一个 txt 文件
每一行是一个样本: label@@@@@text
所以在后文加载数据的时候使用的5个@
‘’’

1.需要创建一个类data_generator,这个类继承DataGenerator类(bert4kreas.snippets)这个类主要是做数据生成的迭代器

from bert4keras.snippets import DataGenerator,sequence_padding  # 导入相关的包
from bert4keras.tokenizers import Tokenizer  # 导入分词器
from bert4keras.snippets import open as bert4keras_open
import json
import os
import random

# 这是一个数据生成器
class data_generator(DataGenerator):
	# data:数据
	# max_len:句子的最大长度
	# batch_size:小批量数据的条数
    def __init__(self, data, max_len, batch_size, vocab_path, buffer_size=None):
        self.max_len = max_len
        self.batch_size = batch_size
        self.vocab_path = vocab_path
        self.data = data
        self.buffer_size = buffer_size
        # 判断 self.data 中是否有 len 的方法
        if hasattr(self.data, "__len__"):
            self.steps = len(self.data) // self.batch_size
            if len(self.data) % self.batch_size != 0:
                self.steps += 1
        else:
            self.steps = None
        self.buffer_size = buffer_size or self.batch_size * 1000

    def __iter__(self, random=False):
    	# 建立一个分词器
        tokenizer = Tokenizer(self.vocab_path, do_lower_case=True)
        # 创建存 token,segment,label的列表
        batch_token_ids, batch_segment_ids, batch_label_ids = [], [], []
        # is_end 判断到没到最后一条数据(最后一条数据is_end=True,否则is_end=False)
        for is_end, (label, text) in self.sample(random):
        	# 将 text 文本进行编码得到 token_ids 与 segment_ids
            token_ids, segment_ids = tokenizer.encode(text, maxlen=self.max_len)
            # 加入创建好的列表
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_label_ids.append([label])
			# 判断 (是否是最后一条数据)or(是否达到了一个batch的数量)
            if is_end or len(batch_token_ids) == self.batch_size:
            	# 对于每个 token,segment 进行补全,根据 max_len 进行长度的统一
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_label_ids = sequence_padding(batch_label_ids)
                # 返回每个 batch 的数据
                yield [batch_token_ids, batch_segment_ids], batch_label_ids
                # 重新计数,下一个 batch
                batch_token_ids, batch_segment_ids, batch_label_ids = [], [], []
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

2. 创建 DateProcess() 类:

class DateProcess(object):
    def __init__(self, vocab_path, max_len, batch_size):
        self.vocab_path = vocab_path
        self.max_len = max_len
        self.batch_size = batch_size
	
    def get_label2id(self, train_data_path, model_output_path):
    	'''
        这个主要实现了将标签与 id 进行替换
        :param train_data_path: 数据的路径
        :param model_output_path: 模型输出的路径
        :return: 返回两个字典,一个是 label-> id 一个是 id-> label
        '''
        label_list = []
        with bert4keras_open(train_data_path, "r", encoding="utf-8") as f:
            for text in f:
                label_list.append(text.strip().split("@@@@@")[0])
		# 将label进行去重
        labels = sorted(set(label_list))
        id2label = {}
        label2id = {}
        for index, label in enumerate(labels):
            label2id[label] = index
            id2label[index] = label
		# 为了以后预测模型时 id 与 label 好对应,所以将 id2label 进行 json 格式的保存
        with bert4keras_open(os.path.join(model_output_path + "id2abel_new_new.json"), "w") as f:
            json.dump(id2label, f, ensure_ascii=False)

        return label2id, id2label

    def load_data(self, file_name, label2id):
    	'''
        这个函数用来加载数据,将样本存到一个列表中,每个样本是一个元组 [(label,text),(label,text),(label,text)..........(label,text)]
        这里就通过 label2id(字典)将文本的 label 转化为了 id
        :param file_name: 
        :param label2id: 
        :return: 返回一个列表
        '''
        data_list = []
        with bert4keras_open(file_name, "r", encoding="utf-8") as f:
            for line in f:
                label = label2id[line.strip().split("@@@@@")[0]]
                text = line.strip().split("@@@@@")[1]
                data_list.append((label, text))

        return data_list

    def generate_data(self, train_data_path, model_output_path):
    	'''
        这个函数就是将 加载数据,label->id,生成 训练集,测试集,验证机生成器的一个函数
        :param train_data_path: 
        :param model_output_path: 
        :return: 返回5个参数,label2id, id2label, train_data_generate, vail_data_generate, test_data_generate
        '''
        # 生成 label-> id 的字典
        label2id, id2label = self.get_label2id(train_data_path, model_output_path)
        # 生成 数据的列表
        data_list = self.load_data(train_data_path, label2id)
        length = len(data_list)
        # 进行数据的打乱
        random.shuffle(data_list)
        # 划分训练集,验证集,测试集
        train_data = data_list[:int(0.8 * length)]
        vail_data = data_list[int(0.8 * length):int(0.9 * length)]
        test_data = data_list[int(0.9 * length):]
		
		# 创建三个数据生成器
        train_data_generate = data_generator(train_data, self.max_len, self.batch_size, self.vocab_path)
        vail_data_generate = data_generator(vail_data, self.max_len, self.batch_size, self.vocab_path)
        test_data_generate = data_generator(test_data, self.max_len, self.batch_size, self.vocab_path)

        return label2id, id2label, train_data_generate, vail_data_generate, test_data_generate
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72

3.测试

if __name__ == '__main__':
    data_process = DateProcess("../data/vocab.txt", max_len=128, batch_size=32)
    _, _, train, vail, test = data_process.generate_data("../data/result_sample.txt", "../data/output")
    for token_and_segment,label in test:
        print("*"*1000)
        print(token_and_segment[0])
        print()
        print(token_and_segment[1])
        print()
        print(label[:, 0])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

测试结果:
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/106865
推荐阅读
相关标签
  

闽ICP备14008679号