赞
踩
机器翻译
将序列从一种语言自动翻译成另一种语言
如何将预处理的数据加载到小批量进行训练
1.下载数据集
下载“英-法”数据集
①数据集的每一行都是制表符\t分隔的文本序列对
②序列对由英文序列和翻译后的法语序列组成
③源语言是英语,目标语言是法语
- import os
- import torch
- from d2l import torch as d2l
-
- # 下载和预处理数据集
- # @save
- d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',
- '94646ad1522d915e7b0f9296181140edcf86a4f5')
-
-
- # @save
- def read_data_nmt():
- """载入“英语-法语”数据集"""
- data_dir = d2l.download_extract('fra-eng')
- with open(os.path.join(data_dir, 'fra.txt'), 'r',
- encoding='utf-8') as f:
- return f.read()

2.数据集预处理
①使用空格代替连续的空格
②使用小写替换大写字符,在单词和标点之间加入空格,标点也当作词元
- def preprocess_nmt(text):
- '''数据集'''
-
- def no_space(char, prev_char):
- return char in set(',.!?') and prev_char != ' '
-
- # 使用空格代替连续空格,使用小写替换大写
- text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
- # 在单词和标点符号之间插入空格
- out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
- for i, char in enumerate(text)]
- return ''.join(out)
3.词元化
对文本序列进行词元,每个词元是一个词或者是一个标点符号
- def tokenize_nmt(text, num_examples=None):
- '''词元化 数据集'''
- source, target = [], []
- for i, line in enumerate(text.split('\n')):
- if num_examples and i > num_examples:
- break
- parts = line.split('\t')
- if len(parts) == 2:
- source.append(parts[0].split(' '))
- target.append(parts[1].split(' '))
- return source, target
4.词表
分为源语言和目标语言的两个词表。
①使用单词级词元化时,词表的大小明显大于使用字符级的词表大小。为了解决这个问题,可以将出现次数少于某一个特定值的低频词为未知词元<unk>
②填充词元<pad>,序列的开始词元<bos>,序列的结束词元<eos>
- # 词表--数据集由语言对组成,出现次数少于2次是低频词<unk>,填充词元<pad>,开始词元<bos>,结束词元<eos>
- src_vocab = d2l.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
5.加载数据集
为了提高效率,可以通过截断和填充方式实现处理小批量文本序列进行训练
- # 加载数据集 num_steps固定长度 截断和填充实现文本序列 有相同的长度
- # 长度达不到,填充<pad> 长度超过截取 并且丢弃
- # truncate_pad函数截断或者填充文本序列
- def truncate_pad(line, num_steps, padding_token):
- '''截断或者填充文本序列'''
- if len(line) > num_steps:
- return line[:num_steps] # 截断
- return line + [padding_token] * (num_steps - len(line)) # 填充
- def build_array_nmt(lines, vocab, num_steps):
- '''文本序列转换为小批量'''
- lines = [vocab[l] for l in lines]
- lines = [l + [vocab['<eos>']] for l in lines]
- array = torch.tensor([truncate_pad(
- l, num_steps, vocab['<pad>']) for l in lines
- ])
- valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
- return array, valid_len
6.模型训练
- def load_data_nmt(batch_size, num_steps, num_examples=600):
- '''返回数据集的迭代器和词表'''
- text = preprocess_nmt(read_data_nmt())
- source, target = tokenize_nmt(text, num_examples)
- src_vocab = d2l.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
- tgt_vocab = d2l.Vocab(target, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
- src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
- tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
- data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
- data_iter = d2l.load_array(data_arrays, batch_size)
- return data_iter, src_vocab, tgt_vocab
第一个小批量数据
- train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)
- for X, X_valid_len, Y, Y_valid_len in train_iter:
- print('X:', X.type(torch.int32))
- print('X的有效长度:', X_valid_len)
- print('Y:', Y.type(torch.int32))
- print('Y的有效长度:', Y_valid_len)
- break
总结
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。