赞
踩
原链接:Text classification with the torchtext library — PyTorch Tutorials 1.11.0+cu102 documentation
(1)导入数据集(经常会出现数据集下载失败的情况),有大佬的网盘:https://pan.baidu.com/s/1Rz_XoaTZWSRiHGOwkACosQ,提取码:j0no
下载完直接放到当前打开jupyter notebook的目录下,地址就到AG_NEWS.data文件夹即可
(现在的版本好像要加上root=‘地址’,不然会报错)
- import torch
- from torchtext.datasets import AG_NEWS
- path = r'E:\Notebook\自然语言处\Text_classification_with_the_torchtext_library\AG_NEWS.data'
- train_iter = iter(AG_NEWS(root=path, split='train'))
(2)构建词汇表
- from torchtext.data.utils import get_tokenizer #导入分词工具
- from torchtext.vocab import build_vocab_from_iterator #使用迭代器构建词表
-
- tokenizer = get_tokenizer('basic_english') #创建分词器对象,采用英文分词
- train_iter = AG_NEWS(root=path, split='train') #获取数据集,并生成迭代器
-
- def yield_tokens(data_iter):
- for _, text in data_iter: #获取每一条的标签label和内容text
- yield tokenizer(text) #对获取内容分词,并返回。yield返回一个迭代器对象
-
- #将未能识别的单词设置为<unk>
- vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
-
- #设置<unk>的索引为默认索引,一旦遇到不能识别单词,转为<unk>的索引值
- vocab.set_default_index(vocab['<unk>'])
(3)获取每条数据的label和text
- text_pipeline = lambda x: vocab(tokenizer(x)) #获取每一条的text的索引表示
- label_pipeline = lambda x: int(x) - 1 #获取对应的label
-
- #演示
- text_pipeline('here is the an example')
- >>> [475, 21, 2, 30, 5297]
- label_pipeline('10')
- >>> 9
<
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。