赞
踩
DefaultDataCollator(return_tensors: str = 'pt')
默认的数据收集器,只是将transformers中的Dataset数据对象转换成tensorflow或pytorch可以处理的Dataset数据对象。没有像DataCollatorWithPadding那样,在转换数据类型的同时,也进行数据的填充。
参数return_tensors表示返回数据的类型。有三个可选项,分别是"tf"、“pt”、“np”,分别表示tensorflow可以处理的数据类型,pytorch可以处理的数据类型以及numpy数据类型。
def preprocess_fn(data): data = {k: sum(data[k], []) for k in data.keys()} total_length = len(data[list(data.keys())[0]]) total_length = (total_length // 128) * 128 # 128表示每一组的句子的长度 result = {k: [v[i: i + 128] for i in range(0, total_length, 128)] for k, v in data.items()} result["label"] = result["input_ids"].copy() return result dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1") tokenizer = transformers.AutoTokenizer.from_pretrained("distilgpt2") # 使用默认的数据收集器 data_collator = transformers.DefaultDataCollator(return_tensors="tf") dataset = dataset.map(function=lambda data: tokenizer(data["text"], truncation=True), batched=True, batch_size=1000, remove_columns=["text"]) dataset = dataset.map(function=preprocess_fn, batched=True, batch_size=1000) train_dataset = dataset["train"].to_tf_dataset(columns=["input_ids", "attention_mask"], batch_size=16, shuffle=True, collate_fn=data_collator, label_cols=["labels"])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。