赞
踩
用深度学习做nlp也有一段时间了,熟悉这块内容的同学都知道,实践算法的时候,写模型是个简单的事,最麻烦的是数据处理,数据处理不仅会浪费我们大部分时间,而且会消耗很大的计算资源,浪费人力物力。今年开始接触pytorch,简洁的API,动态图,更加灵活的编写模式,诸多优点不用多说。最近尝试使用torchtext工具,这里想先说明的是,torchtext并不是pytorch所独有的,使用其它深度学习框架,torchtext仍然可以使用。但是比较麻烦的是,并没有很好很全面的torchtext教程,给同学们入门造成了一定麻烦,这也是我写这篇文章的目的。
首先整体介绍一下torchtext的组件。
torchtext包含以下组件:
Field 包含一写文本处理的通用参数的设置,同时还包含一个词典对象,可以把文本数据表示成数字类型,进而可以把文本表示成需要的tensor类型
以下是Field对象包含的参数:
简单的栗子如下,建一个Field对象
TEXT = data.Field(tokenize=data.get_tokenizer('spacy'),
init_token='<SOS>', eos_token='<EOS>',lower=True)
torchtext的Dataset是继承自pytorch的Dataset,提供了一个可以下载压缩数据并解压的方法(支持.zip, .gz, .tgz)
splits方法可以同时读取训练集,验证集,测试集
TabularDataset可以很方便的读取CSV, TSV, or JSON格式的文件,例子如下:
train, val, test = data.TabularDataset.splits(
path='./data/', train='train.tsv',
validation='val.tsv', test='test.tsv', format='tsv',
fields=[('Text', TEXT), ('Label', LABEL)])
加载数据后可以建立词典,建立词典的时候可以使用与训练的word vector
TEXT.build_vocab(train, vectors="glove.6B.100d")
Iterator是torchtext到模型的输出,它提供了我们对数据的一般处理方式,比如打乱,排序,等等,可以动态修改batch大小,这里也有splits方法 可以同时输出训练集,验证集,测试集
参数如下:
使用方式如下:
train_iter, val_iter, test_iter = data.Iterator.splits(
(train, val, test), sort_key=lambda x: len(x.Text),
batch_sizes=(32, 256, 256), device=-1)
torchtext提供常用文本数据集,并可以直接加载使用:
train,val,test = datasets.WikiText2.splits(text_field=TEXT)
现在包含的数据集包括:
完整例子如下,短短几行就把词典和数据batch做好了。
import jieba import torch from torchtext import data, datasets regex = re.compile(r'[^\u4e00-\u9fa5aA-Za-z0-9]') def tokenizer(text): # create a tokenizer function text = regex.sub(' ', text) return [word for word in jieba.cut(text) if word.strip()] text = data.Field(sequential=True, tokenize=tokenizer, fix_length=150) label = data.Field(sequential=False, use_vocab=False) train, val = data.TabularDataset.splits( path='./data/', train='/brucewu/projects/pytorch_tutorials/chinese_text_cnn/data/train.tsv', validation='/brucewu/projects/pytorch_tutorials/chinese_text_cnn/data/dev.tsv', format='tsv', fields=[('text', text), ('label', label)]) text.build_vocab(train, val, vectors=Vectors(name="/brucewu/projects/pytorch_tutorials/chinese_text_cnn/data/eco_article.vector")) train_iter, val_iter = data.Iterator.splits( (train, val), sort_key=lambda x: len(x.text), batch_sizes=(32, 256), device=-1) vocab = text.vocab
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。