当前位置:   article > 正文

TorchText简介

torchtext

TorchText简介

注:此文为笔者学习 DataWhale 开源教程《深入浅出 Pytorch》所做学习笔记,仅记录平时接触较少且笔者认为价值较高的知识点。

一、包的介绍

​ TorchText 库是专用于自然语言处理的、基于 Pytorch 生态的第三方库,提供了对文本数据进行预处理的各种工具,但是没有如 TorchVision 库一样的各种已定义模型,如需使用定义好的预训练模型,需要使用 transformers 库。

​ TorchText 库可直接使用 pip 安装:

pip install torchtext
  • 1

二、构建数据集

​ TorchText 提供了 Field 对象来定义字段的处理方式,可以针对不同的数据样本定义不同的处理形式。定义 Field 对象,首先需要定义一个分词器,接着根据指定参数分别定义样本和标签的 Field:

tokenize = lambda x: x.split()
# 分词器
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
# 用于样本切分的Field
LABEL = data.Field(sequential=False, use_vocab=False)
# 用于标签切分的Field
'''
sequential:数据是否是顺序表示的,若不是,则不使用分词器处理
tokenize:分词器
lower:是否将字符串全部转为小写
fix_length:将此字段所有实例都将填充到一个固定的长度
use_vocab:是否使用词表,默认为True,若不是则输入数据应为数字类型
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

​ 定义完 Field 之后,需要定义一个构建数据集的函数:

from torchtext import data
from tqdm import tqdm
def get_dataset(csv_data, text_field, label_field, test=False):
    # 用于构建 dataset
    fields = [("id", None), # we won't be needing the id, so we pass in None as the field
                 ("comment_text", text_field), ("toxic", label_field)]   
    # 此处假设读入的csv数据有三列:id——不使用;comment_text——样本文本;toxic——标签
    examples = []
    # 样本集合

    if test:
        # 如果为测试集,则不加载label
        for text in tqdm(csv_data['comment_text']):
            # 读取样本文本,使用tqdm可视化进度
            examples.append(data.Example.fromlist([None, text, None], fields))
            # 利用fromlist函数来实现格式的转换,参数列表分别对应csv数据格式
    else:
        for text, label in tqdm(zip(csv_data['comment_text'], csv_data['toxic'])):
            examples.append(data.Example.fromlist([None, text, label], fields))
    return examples, fields
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

​ 接着使用该函数构建 dataset:

import pandas as pd
train_data = pd.read_csv('train.csv')
valid_data = pd.read_csv('valid.csv')
test_data = pd.read_csv("test.csv")
# 读取csv格式数据
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True)
LABEL = data.Field(sequential=False, use_vocab=False)
# 之前定义的两个Field

# 得到构建Dataset所需的examples和fields
train_examples, train_fields = get_dataset(train_data, TEXT, LABEL)
valid_examples, valid_fields = get_dataset(valid_data, TEXT, LABEL)
test_examples, test_fields = get_dataset(test_data, TEXT, None, test=True)
# 构建Dataset数据集
train = data.Dataset(train_examples, train_fields)
valid = data.Dataset(valid_examples, valid_fields)
test = data.Dataset(test_examples, test_fields)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

​ 构建的 train、valid、test 即可在 Pytorch 中使用。

三、词表及迭代器

​ TorchText 可使用 Field 对象的 build_vocab 函数来构建词表:

TEXT.build_vocab(train)
  • 1

​ 在 TorchText 中,DataLoader 被替换成 Iterator 和 BucketIterator,可通过下文代码使用:

from torchtext.data import Iterator, BucketIterator
# 若只针对训练集构造迭代器
# train_iter = data.BucketIterator(dataset=train, batch_size=8, shuffle=True, sort_within_batch=False, repeat=False)

# 同时对训练集和验证集进行迭代器的构建
train_iter, val_iter = BucketIterator.splits(
        (train, valid), # 构建数据集所需的数据集
        batch_sizes=(8, 8),
        device=-1, # 如果使用gpu,此处将-1更换为GPU的编号
        sort_key=lambda x: len(x.comment_text), # the BucketIterator needs to be told what function it should use to group the data.
        sort_within_batch=False
)

test_iter = Iterator(test, batch_size=8, device=-1, sort=False, sort_within_batch=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

​ 其中各个参数同 DataLoader 类似。

四、注意

​ 作为一个开发时间较早、开发时自然语言处理尚未进入预训练模型阶段的第三方库,其没有提供各种预训练模型,而数据预处理的 API 也和 transformers 库中的一些 API 相冲突,因此,更建议目前在自然语言处理尤其是使用预训练模型时直接使用 transformers 库。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/378371
推荐阅读
相关标签
  

闽ICP备14008679号