'''
The data here is a txt file.
Each line is one sample: label@@@@@text
That is why the loading code below splits on the five-@ delimiter.
'''
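As a quick sanity check, this is how one line in that format comes apart (a minimal sketch; the sample line and its label are made up for illustration):

# Hypothetical sample line in the label@@@@@text format described above
line = "体育@@@@@某队昨晚夺得了联赛冠军"

label, text = line.strip().split("@@@@@")
print(label)  # -> 体育
print(text)   # -> 某队昨晚夺得了联赛冠军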
from bert4keras.snippets import DataGenerator, sequence_padding  # data-generation helpers
from bert4keras.tokenizers import Tokenizer  # BERT tokenizer
from bert4keras.snippets import open as bert4keras_open
import json
import os
import random


# A mini-batch data generator built on bert4keras's DataGenerator
class data_generator(DataGenerator):
    # data: list of samples
    # max_len: maximum sentence length
    # batch_size: number of samples per mini-batch
    def __init__(self, data, max_len, batch_size, vocab_path, buffer_size=None):
        self.max_len = max_len
        self.batch_size = batch_size
        self.vocab_path = vocab_path
        self.data = data
        # If self.data supports len(), precompute the number of steps per epoch
        if hasattr(self.data, "__len__"):
            self.steps = len(self.data) // self.batch_size
            if len(self.data) % self.batch_size != 0:
                self.steps += 1
        else:
            self.steps = None
        self.buffer_size = buffer_size or self.batch_size * 1000
        # Build the tokenizer once here rather than on every pass over the data
        self.tokenizer = Tokenizer(self.vocab_path, do_lower_case=True)

    def __iter__(self, random=False):
        # Lists that accumulate the token ids, segment ids and label ids of one batch
        batch_token_ids, batch_segment_ids, batch_label_ids = [], [], []
        # is_end is True only for the last sample of the dataset
        for is_end, (label, text) in self.sample(random):
            # Encode the text into token_ids and segment_ids
            token_ids, segment_ids = self.tokenizer.encode(text, maxlen=self.max_len)
            # Append the encoded sample to the batch lists
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_label_ids.append([label])
            # Emit a batch once we reach batch_size or run out of data
            if is_end or len(batch_token_ids) == self.batch_size:
                # Pad every sequence in the batch to a common length
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_label_ids = sequence_padding(batch_label_ids)
                # Yield one batch
                yield [batch_token_ids, batch_segment_ids], batch_label_ids
                # Reset for the next batch
                batch_token_ids, batch_segment_ids, batch_label_ids = [], [], []
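Before moving on, it helps to see what sequence_padding does to a ragged batch. Below is a minimal sketch with made-up token ids (not real tokenizer output):

from bert4keras.snippets import sequence_padding

# Two toy token-id sequences of different lengths (ids are made up)
batch = [[101, 2769, 102],
         [101, 2769, 3221, 8043, 102]]
padded = sequence_padding(batch)
print(padded.shape)  # (2, 5): every sequence is right-padded with 0 to the batch maximum
print(padded)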
class DateProcess(object):
    def __init__(self, vocab_path, max_len, batch_size):
        self.vocab_path = vocab_path
        self.max_len = max_len
        self.batch_size = batch_size

    def get_label2id(self, train_data_path, model_output_path):
        '''
        Build the mappings between labels and ids.
        :param train_data_path: path to the data file
        :param model_output_path: directory where model artifacts are written
        :return: two dicts, one label -> id and one id -> label
        '''
        label_list = []
        with bert4keras_open(train_data_path, "r", encoding="utf-8") as f:
            for text in f:
                label_list.append(text.strip().split("@@@@@")[0])
        # Deduplicate the labels (sorted, so the id assignment is deterministic)
        labels = sorted(set(label_list))
        id2label = {}
        label2id = {}
        for index, label in enumerate(labels):
            label2id[label] = index
            id2label[index] = label
        # Save id2label as JSON so ids can be mapped back to labels at prediction time
        with bert4keras_open(os.path.join(model_output_path, "id2label_new_new.json"), "w") as f:
            json.dump(id2label, f, ensure_ascii=False)
        return label2id, id2label

    def load_data(self, file_name, label2id):
        '''
        Load the data into a list where each sample is a tuple:
        [(label, text), (label, text), (label, text), ..., (label, text)]
        The text label is converted to an id via the label2id dict.
        :param file_name:
        :param label2id:
        :return: a list of (label_id, text) tuples
        '''
        data_list = []
        with bert4keras_open(file_name, "r", encoding="utf-8") as f:
            for line in f:
                label = label2id[line.strip().split("@@@@@")[0]]
                text = line.strip().split("@@@@@")[1]
                data_list.append((label, text))
        return data_list

    def generate_data(self, train_data_path, model_output_path):
        '''
        Load the data, map labels to ids, and build the train/validation/test generators.
        :param train_data_path:
        :param model_output_path:
        :return: five values: label2id, id2label, train_data_generate, vail_data_generate, test_data_generate
        '''
        # Build the label -> id dicts
        label2id, id2label = self.get_label2id(train_data_path, model_output_path)
        # Load the sample list
        data_list = self.load_data(train_data_path, label2id)
        length = len(data_list)
        # Shuffle the data
        random.shuffle(data_list)
        # Split into train / validation / test (80% / 10% / 10%)
        train_data = data_list[:int(0.8 * length)]
        vail_data = data_list[int(0.8 * length):int(0.9 * length)]
        test_data = data_list[int(0.9 * length):]
        # Create the three data generators
        train_data_generate = data_generator(train_data, self.max_len, self.batch_size, self.vocab_path)
        vail_data_generate = data_generator(vail_data, self.max_len, self.batch_size, self.vocab_path)
        test_data_generate = data_generator(test_data, self.max_len, self.batch_size, self.vocab_path)
        return label2id, id2label, train_data_generate, vail_data_generate, test_data_generate
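For intuition about what get_label2id returns, here is the same logic on a toy label set (the labels are made up; the real ones come from the first field of each line):

# Deduplicate and sort, then number the labels
labels = sorted(set(["news", "sports", "finance", "sports"]))
label2id = {label: index for index, label in enumerate(labels)}
id2label = {index: label for index, label in enumerate(labels)}
print(label2id)  # {'finance': 0, 'news': 1, 'sports': 2}
print(id2label)  # {0: 'finance', 1: 'news', 2: 'sports'}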
if __name__ == '__main__':
    data_process = DateProcess("../data/vocab.txt", max_len=128, batch_size=32)
    _, _, train, vail, test = data_process.generate_data("../data/result_sample.txt", "../data/output")
    for token_and_segment, label in test:
        print("*" * 1000)
        print(token_and_segment[0])
        print()
        print(token_and_segment[1])
        print()
        print(label[:, 0])
Test output: for each batch, the script prints a separator line, the padded token-id matrix, the segment-id matrix, and the batch's label vector.
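These generators plug straight into Keras training. DataGenerator also provides a forfit() method that wraps __iter__ in an endless loop so Keras can keep drawing batches epoch after epoch. A minimal sketch, assuming model is an already-compiled classifier (not built in this post):

# Assumes `model` is a compiled Keras classifier defined elsewhere
model.fit(
    train.forfit(),               # endless batch generator for Keras
    steps_per_epoch=train.steps,  # computed in data_generator.__init__
    epochs=5,
    validation_data=vail.forfit(),
    validation_steps=vail.steps,
)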