赞
踩
给定数据集的每一行格式为:英文 + "\t" + 中文,例如:
He knows better than to marry her. 他聰明到不會娶她。
He had hoped to succeed, but he didn't. 他本希望可以成功,但是他没有。
将每一行按 "\t" 分割为英文和中文两部分:英文转为小写后用 nltk 分词,中文按单个汉字切分,两边首尾分别加上 "BOS" 和 "EOS" 标记,结果分别存入 en 和 cn 两个列表。
# Tab-separated parallel corpora: one "english<TAB>chinese" pair per line.
train_file, dev_file = 'data/translate_train.txt', 'data/translate_dev.txt'
def load_data(filename, tokenize=None):
    """Read a tab-separated English/Chinese parallel corpus.

    Each non-blank line must hold the English sentence, a tab, then the
    Chinese sentence (any further tab-separated columns are ignored).
    English is lower-cased and word-tokenized; Chinese is split into single
    characters. Both sides are wrapped in "BOS" ... "EOS" markers.

    Args:
        filename: path of the UTF-8 encoded corpus file.
        tokenize: optional callable mapping an English string to a list of
            tokens. Defaults to ``nltk.word_tokenize`` (requires ``nltk`` to
            be imported at module level).

    Returns:
        (en, cn): two parallel lists of token lists.

    Raises:
        ValueError: if a non-blank line contains no tab separator.
    """
    if tokenize is None:
        # Resolved lazily so callers that pass their own tokenizer do not
        # need nltk at all.
        tokenize = nltk.word_tokenize
    en, cn = [], []
    with open(filename, 'r', encoding='utf-8') as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # A blank line previously crashed with IndexError on line[1].
                continue
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    "{}:{}: expected 'english<TAB>chinese', got {!r}".format(
                        filename, lineno, line))
            en.append(["BOS"] + tokenize(fields[0].lower()) + ['EOS'])
            # Chinese is modelled character by character.
            cn.append(["BOS"] + list(fields[1]) + ['EOS'])
    return en, cn
# Tokenize both splits; each call yields parallel (english, chinese) lists.
(train_en, train_cn), (dev_en, dev_cn) = [load_data(path) for path in (train_file, dev_file)]
分别建立中英文字典
from collections import Counter
UNK_IDX = 0  # id reserved for tokens missing from the vocabulary
PAD_IDX = 1  # id reserved for the padding token


def build_dict(sentences, max_words=50000):
    """Build a token -> id vocabulary from tokenized sentences.

    The ``max_words`` most frequent tokens receive ids starting at 2 (most
    frequent first); ids 0 and 1 belong to the special 'UNK' and 'PAD'
    entries.

    Args:
        sentences: iterable of token lists.
        max_words: maximum number of corpus tokens to keep.

    Returns:
        (word_dict, total_words), where ``total_words`` is the number of
        kept tokens plus the two special entries.
    """
    freq = Counter(token for sentence in sentences for token in sentence)
    kept = freq.most_common(max_words)
    vocab = {}
    for next_id, (token, _count) in enumerate(kept, start=2):
        vocab[token] = next_id
    # A corpus token literally spelled 'UNK' or 'PAD' is overwritten here.
    vocab['UNK'] = UNK_IDX
    vocab['PAD'] = PAD_IDX
    return vocab, len(kept) + 2
# Vocabularies come from the training split only, so tokens seen only in
# the dev split are looked up with a fallback during encoding.
en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
# Reverse lookups: id -> token.
inv_en_dict = dict((idx, tok) for tok, idx in en_dict.items())
inv_cn_dict = dict((idx, tok) for tok, idx in cn_dict.items())
根据建立的字典,将英文和中文句子中的词替换为对应的数字索引(字典中不存在的词替换为 UNK_IDX)。
替换成数字之后,需要按照英文句子的长度(词数)进行排序,是否排序由 sort_by_len 参数决定。
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    """Map token lists to id lists and optionally sort pairs by English length.

    Tokens missing from a dictionary are mapped to ``UNK_IDX``.

    Args:
        en_sentences: list of English token lists.
        cn_sentences: list of Chinese token lists, parallel to ``en_sentences``.
        en_dict: English token -> id mapping.
        cn_dict: Chinese token -> id mapping.
        sort_by_len: if True, reorder both outputs by ascending English
            length. The sort is stable, so equal-length pairs keep their
            original relative order.

    Returns:
        (out_en_sentences, out_cn_sentences): parallel lists of id lists.

    Raises:
        ValueError: if the two input lists differ in length (previously this
            either raised IndexError or silently dropped pairs).
    """
    if len(en_sentences) != len(cn_sentences):
        raise ValueError(
            "en_sentences and cn_sentences differ in length: {} vs {}".format(
                len(en_sentences), len(cn_sentences)))
    out_en_sentences = [[en_dict.get(w, UNK_IDX) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, UNK_IDX) for w in sent] for sent in cn_sentences]
    if sort_by_len:
        # Placing similar-length sentences next to each other means that
        # consecutive mini-batches need less padding.
        order = sorted(range(len(out_en_sentences)),
                       key=lambda i: len(out_en_sentences[i]))
        out_en_sentences = [out_en_sentences[i] for i in order]
        out_cn_sentences = [out_cn_sentences[i] for i in order]
    return out_en_sentences, out_cn_sentences
# Encode both splits with the training vocabularies, sorted by English length.
(train_en, train_cn), (dev_en, dev_cn) = [
    encode(en_part, cn_part, en_dict, cn_dict, sort_by_len=True)
    for en_part, cn_part in ((train_en, train_cn), (dev_en, dev_cn))
]
划分多个batch
划分数据返回的数据格式为列表,列表成员为元组类型,元组成员为(X_data, X_len, Y_data, Y_len)。其中X_data表示batchSize个英文句子(最后一个batch可能不足batchSize),Y_data为对应的中文翻译;X_len和Y_len为每个句子补齐前的实际长度;max_en_len / max_cn_len 为该batch内最长句子的长度。
X_data.shape=(batchSize, max_en_len)
X_len.shape=(batchSize,)
Y_data.shape=(batchSize, max_cn_len)
Y_len.shape=(batchSize,)
(1) 生成batch
根据数据个数和batchSize划分多个batch,返回数据格式为list,list成员为由连续索引构成的numpy数组(最后一个batch可能不足batchSize)。
import numpy as np
def get_mini_batches(n, batch_size, shuffle=False):
    """Split the index range ``0 .. n-1`` into consecutive mini-batches.

    Args:
        n: number of examples.
        batch_size: maximum number of indices per batch; the last batch may
            be smaller.
        shuffle: if True, shuffle the *order of the batches* using NumPy's
            global RNG. Indices inside each batch stay consecutive, so
            length-sorted data keeps similar lengths together.

    Returns:
        list of 1-D ``np.ndarray`` index arrays (empty list when ``n <= 0``).

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    # Start offset of each batch: [0, batch_size, 2 * batch_size, ...].
    starts = np.arange(0, n, batch_size)
    if shuffle:
        np.random.shuffle(starts)
    return [np.arange(start, min(start + batch_size, n)) for start in starts]
# Demo: 100 examples in batches of 12; the final batch holds the last 4.
# Expected output:
# [ array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
#   array([12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]),
#   ...
#   array([84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95]),
#   array([96, 97, 98, 99])
# ]
data_len, batch_size = 100, 12
print(get_mini_batches(data_len, batch_size))
(2) 统一每个batch中句子的长度(英文和中文句子都使用此函数),不足最大长度的部分补零
import numpy as np
# Toy batch of variable-length id sequences used to demo prepare_data.
seqs = [
    [1, 3, 4, 12, 3],
    [2, 34, 3, 1],
    [12, 34, 1],
    [2],
]
def prepare_data(seqs, pad_idx=0):
    """Pad a batch of id sequences into a rectangular int32 matrix.

    Args:
        seqs: list of id lists (lengths may differ).
        pad_idx: value written after the end of each shorter sequence.
            Defaults to 0 for backward compatibility. Note that in this
            file's vocabulary 0 is UNK_IDX and PAD_IDX is 1, so pass
            ``pad_idx=PAD_IDX`` when padding must be distinguishable from UNK.

    Returns:
        (x, x_lengths): ``x`` has shape ``(len(seqs), max_len)`` and dtype
        int32; ``x_lengths`` has shape ``(len(seqs),)`` and holds each
        sequence's original length. An empty ``seqs`` yields shapes
        ``(0, 0)`` and ``(0,)`` instead of raising.
    """
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    # np.max([]) raises ValueError; the builtin max with a default does not.
    max_len = max(lengths, default=0)
    x = np.full((n_samples, max_len), pad_idx, dtype='int32')
    x_lengths = np.array(lengths, dtype='int32')
    for idx, (seq, length) in enumerate(zip(seqs, lengths)):
        x[idx, :length] = seq
    return x, x_lengths
"""
x = [[ 1 3 4 12 3]
[ 2 34 3 1 0]
[12 34 1 0 0]
[ 2 0 0 0 0]]
x_len = [5 4 3 1]
"""
x, x_len = prepare_data(seqs)
print(x)
print(x_len)
(3) 构造数据集
en_sentences: 英文句子列表(list),每个成员是由单词索引构成的list
cn_sentences: 中文句子列表(list),每个成员是由汉字索引构成的list
def gen_examples(en_sentences, cn_sentences, batch_size, shuffle=False):
    """Cut parallel id sequences into padded mini-batches.

    Args:
        en_sentences: list of English id lists.
        cn_sentences: list of Chinese id lists, parallel to ``en_sentences``.
        batch_size: maximum number of sentence pairs per batch.
        shuffle: forwarded to ``get_mini_batches``; when True the order of
            the batches is shuffled. Defaults to False (original behavior).

    Returns:
        list of ``(mb_x, mb_x_len, mb_y, mb_y_len)`` tuples, where each pair
        is the output of ``prepare_data`` for the English and Chinese side
        of one batch.

    Raises:
        ValueError: if the two input lists differ in length.
    """
    if len(en_sentences) != len(cn_sentences):
        raise ValueError(
            "en_sentences and cn_sentences differ in length: {} vs {}".format(
                len(en_sentences), len(cn_sentences)))
    all_ex = []
    for minibatch in get_mini_batches(len(en_sentences), batch_size, shuffle=shuffle):
        mb_x, mb_x_len = prepare_data([en_sentences[t] for t in minibatch])
        mb_y, mb_y_len = prepare_data([cn_sentences[t] for t in minibatch])
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex
数据集下载:
数据集链接:https://pan.baidu.com/s/1RgmRv80zQB71HSze8bQvwA
提取码:ih2c
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。