赞
踩
将 https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M 数据集下载到本地:
import os
import json
from datasets import load_dataset
# Route Hugging Face downloads through a local VPN/proxy (adjust the port to your setup).
os.environ["http_proxy"] = "http://127.0.0.1:21882"
os.environ["https_proxy"] = "http://127.0.0.1:21882"

dataset = load_dataset("YeungNLP/firefly-train-1.1M")
# Save a local copy of the downloaded dataset to this directory.
# NOTE(review): the directory name refers to a different dataset
# (Salesforce/dialogstudio) than the one loaded above; it looks like a
# copy-paste leftover and may be worth renaming.
dataset.save_to_disk("dataset/Salesforce/dialogstudio")

train = dataset['train']
print(len(train))
print(train[3])

# Keep only the couplet samples (kind == 'Couplet') and write them as
# JSON Lines: one {"input": ..., "output": ...} object per line, with the
# sample's 'target' field stored under the 'output' key.
os.makedirs('./dataset', exist_ok=True)  # make sure the output directory exists
num = 0
with open('./dataset/data.json', 'w', encoding='utf-8') as fp:
    # Iterate over rows directly instead of calling train[i] (up to three
    # times per row): each index lookup rebuilds the row, which is slow
    # over ~1.1M rows.
    for row in train:
        if row['kind'] == 'Couplet':
            record = {'input': row['input'], 'output': row['target']}
            fp.write(json.dumps(record, ensure_ascii=False) + '\n')
            num += 1
print(f"已写入{num}条")
数据类型如下:
上述代码只提取其中的对联数据(kind 为 Couplet 的样本),结果如下:
import os  # needed for the proxy settings below; missing in the original snippet

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Route Hugging Face downloads through the local proxy (adjust to your setup).
os.environ["http_proxy"] = "http://127.0.0.1:21882"
os.environ["https_proxy"] = "http://127.0.0.1:21882"

# First, download the tokenizer and model; downloaded files are cached under
# ./t5_model_v1.
# NOTE: cache_dir uses the Hugging Face hub cache layout. To get a plain
# directory that from_pretrained(<path>) can load directly, call
# save_pretrained(<path>) on the tokenizer and model.
tokenizer = T5Tokenizer.from_pretrained("t5-small", cache_dir="./t5_model_v1")
model = T5ForConditionalGeneration.from_pretrained("t5-small", cache_dir="./t5_model_v1")

# Quick sanity check: translate a short English sentence into German.
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
outputs = model.generate(input_ids)
# skip_special_tokens=True drops special tokens (e.g. <pad>, </s>) from the text.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
结果如下:
import os

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Proxy settings carried over from the download step; they only matter if
# from_pretrained has to contact the Hub (e.g. ./t5_model_v2 does not exist).
os.environ["http_proxy"] = "http://127.0.0.1:21882"
os.environ["https_proxy"] = "http://127.0.0.1:21882"

# Load the tokenizer and model from the local directory "t5_model_v2".
# Presumably this directory holds model files prepared beforehand — confirm
# it contains the tokenizer and model files.
tokenizer = T5Tokenizer.from_pretrained("t5_model_v2")
model = T5ForConditionalGeneration.from_pretrained("t5_model_v2")

# Same sanity check as the download step, now using the local copy.
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
outputs = model.generate(input_ids)
# skip_special_tokens=True drops special tokens (e.g. <pad>, </s>) from the text.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
下载的模型文件配置如下:
结果如下:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。