赞
踩
TorchText可以读取三种数据格式:json, tsv (tab separated values 制表分隔值)和csv(comma separated values 逗号分隔值)。
从json开始,你的数据必须是json行格式,也就是说,它必须是这样的:
{
"name": "John", "location": "United Kingdom", "age": 42, "quote": ["i", "love", "the", "united kingdom"]}
{
"name": "Mary", "location": "United States", "age": 36, "quote": ["i", "want", "more", "telescopes"]}
也就是说,每一行都是一个json对象。data/trian.json为例。
然后我们定义字段:
from torchtext import data
from torchtext import datasets
NAME = data.Field()
SAYING = data.Field()
PLACE = data.Field()
接下来,我们必须告诉TorchText哪个字段应用于json对象的哪个元素。
对于json数据,我们必须创建一个字典:
一些注意事项:
fields = {
'name': ('n', NAME), 'location': ('p', PLACE), 'quote': ('s', SAYING)}
现在,在训练循环中,我们可以通过数据迭代器进行迭代并且通过batch.n访问name,通过batch.p访问location,通过batch.s访问quote。
然后我们使用TabularDataset.splits函数创建我们的数据集(train_data和test_data)
path参数指定两个数据集中共同的顶级文件夹,train和test参数指定每个数据集的文件名,例如,这里的train数据集位于data/train.json。
我们告诉函数我们正在使用json数据,并将前面定义的fields字典传递给它。
train_data, test_data = data.TabularDataset.splits(
path = 'data',
train = 'train.json',
test = 'test.json',
format = 'json',
fields = fields
)
如果已经有验证数据集,则可以将其路径作为validation 参数传递。
train_data, valid_data, test_data = data.TabularDataset.splits(
path
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。