赞
踩
我们今天开始分析著名的attention is all you need 论文的pytorch实现的源码解析。
由于项目很大,所以我们会分开几讲来进行讲解。
先上源码:https://github.com/Eathoublu/attention-is-all-you-need-pytorch
大家可以先自行下载并理解。
今天是第一讲,我们先讲解数据的预处理部分:preprocess.py
项目结构:
-transfomer
—__init__.py
—Beam.py
—Constants.py
—Layers.py
—Models.py
—Module.py
—Optim.py
—SubLayers.py
—Translator.py
我们今天先介绍preprocess.py文件即数据预处理的部分,数据清洗以及词表的构建非常重要,我会用注释的方式进行解析。请大家从标号为1的注释开始阅读到标号为12的源码,我将尽可能用简洁的语言解读源码,保证大家都能够读懂。当然,有能力的同学完全可以跳过这一节直接阅读github里面的preprocess.py源码。
好的,现在我们往下翻,翻到def main的位置。
文件源码:
''' Handling the data io ''' import argparse import torch import transformer.Constants as Constants #以上引入了解析命令行参数的库以及pytorch和一个transfomer文件夹下的constants文件 #我们先下翻到def main的位置吧 #1.好了,就是这里。我们首先看main函数: def main(): ''' Main function ''' parser = argparse.ArgumentParser() parser.add_argument('-train_src', required=True) parser.add_argument('-train_tgt', required=True) parser.add_argument('-valid_src', required=True) parser.add_argument('-valid_tgt', required=True) parser.add_argument('-save_data', required=True) parser.add_argument('-max_len', '--max_word_seq_len', type=int, default=50) parser.add_argument('-min_word_count', type=int, default=5) parser.add_argument('-keep_case', action='store_true') parser.add_argument('-share_vocab', action='store_true') parser.add_argument('-vocab', default=None) #2.以上是一些命令行运行时需要传入的参数,required=True的字段是必须传入的,其他是可选 opt = parser.parse_args() #3. 解析命令行参数 opt.max_token_seq_len = opt.max_word_seq_len + 2 # include the <s> and </s> # 4.我们在调用的时候参数里面有一个是告诉程序我们传入的句子里面最多有多少个词,然后程序会自动帮我们+2,因为可能存在的</s>标签。至于为什么要传入这个最长词序列的长度我们先不管它。 # Training set # 5.训练集,以下的两行代码的意思是,我们要调用read_instances_from_file这个函数(本来是有的为了方便阅读在这里我不摆出来了。)这个函数的作用是:传入三个参数(数据集的绝对路径、最长的句子里面有多少个词,是否全是小写)函数的主要功能是逐行读入目标文件的内容(文件中一行就是一个句子),并将每行的句子进行分词转换成一个词的列表,并将所有句子的词的列表组合成一个大的句子的列表,例如:[[什么,?,大清,亡,了,?], [我,爱,时崎狂三], [暴走大事件,更新,了], [我, 来自,东北大学],[他,酒驾,进去,了]] 返回值是就是这样的一个列表啦!只不过源码示例示英文的而已~ train_src_word_insts = read_instances_from_file( opt.train_src, opt.max_word_seq_len, opt.keep_case) train_tgt_word_insts = read_instances_from_file( opt.train_tgt, opt.max_word_seq_len, opt.keep_case) #6.这行是做一个规范,规定数据集的数据条数一定要等于标签集的数据条数,否则我们取同样个数的数据集以及标签集,例如100个data,103个target,那么我们data target都取100个。 if len(train_src_word_insts) != len(train_tgt_word_insts): print('[Warning] The training instance count is not equal.') min_inst_count = min(len(train_src_word_insts), len(train_tgt_word_insts)) train_src_word_insts = train_src_word_insts[:min_inst_count] train_tgt_word_insts = train_tgt_word_insts[:min_inst_count] #7.接下来是将那些不合法的数据和标签清洗掉,例如把有数据,标签只是一个空格这样的数据去掉。 #- Remove empty instances train_src_word_insts, train_tgt_word_insts = list(zip(*[ (s, t) for s, t in zip(train_src_word_insts, train_tgt_word_insts) if s and t])) #8.这一步是制作验证集,方法和上面是一样的,都是调用 read_instances_from_file函数,我就不赘述了。 # Validation set valid_src_word_insts = read_instances_from_file( opt.valid_src, opt.max_word_seq_len, opt.keep_case) valid_tgt_word_insts = read_instances_from_file( opt.valid_tgt, opt.max_word_seq_len, opt.keep_case) #9.接下来的7行代码,是和清洗训练集一样,对验证集进行清洗。 if len(valid_src_word_insts) != len(valid_tgt_word_insts): print('[Warning] The validation instance count is not equal.') min_inst_count = min(len(valid_src_word_insts), len(valid_tgt_word_insts)) valid_src_word_insts = valid_src_word_insts[:min_inst_count] valid_tgt_word_insts = valid_tgt_word_insts[:min_inst_count] #- Remove empty instances valid_src_word_insts, valid_tgt_word_insts = list(zip(*[ (s, t) for s, t in zip(valid_src_word_insts, valid_tgt_word_insts) if s and t])) #9.好的,至此我们已经完成了数据清洗的步骤,得到了训练集以及验证集两个部分。现在我们要构建词表了。 # Build vocabulary #10. 请注意,下面的这几个if opt.vocab到else这个代码块在源码示例里面并没有使用到,因为这几个参数都是可选的,我们大可以跳过,暂时不看。请跳到11.的位置继续阅读。 if opt.vocab: predefined_data = torch.load(opt.vocab) assert 'dict' in predefined_data print('[Info] Pre-defined vocabulary found.') src_word2idx = predefined_data['dict']['src'] tgt_word2idx = predefined_data['dict']['tgt'] else: if opt.share_vocab: print('[Info] Build shared vocabulary for source and target.') word2idx = build_vocab_idx( train_src_word_insts + train_tgt_word_insts, opt.min_word_count) src_word2idx = tgt_word2idx = word2idx #11. 10以下,11以上的代码是可选参数的处理,我们可以暂时不去理解,我们假定我们运行程序的时候,没有传入这些参数,那么我们将会进入下面的else,创建一个新的词表。这个build_vocab_idx函数就是用来将词语转化成词表的:原理很简单,就是将刚刚产生的所有的句子列表里面的所有的词给拿出来,并给每一个词一个编号,做成一个字典并返回,这货就叫做词表。 else: print('[Info] Build vocabulary for source.') src_word2idx = build_vocab_idx(train_src_word_insts, opt.min_word_count) print('[Info] Build vocabulary for target.') tgt_word2idx = build_vocab_idx(train_tgt_word_insts, opt.min_word_count) #12.下面,我们将每一个训练集里面出现过的单词转化为词表里面的一个下标index,并将原本是词语序列构成的句子转化为以词语在词表中的下标序列构成的列表。例如:我=1,爱=2,时崎狂三=3,那么原本的句子[我,爱,时崎狂三]就变成[1, 2, 3] 实现这个功能的函数就是convert_instance_to_idx_seq,它的返回值就是上述的这个列表。 # word to index print('[Info] Convert source word instances into sequences of word index.') train_src_insts = convert_instance_to_idx_seq(train_src_word_insts, src_word2idx) valid_src_insts = convert_instance_to_idx_seq(valid_src_word_insts, src_word2idx) print('[Info] Convert target word instances into sequences of word index.') train_tgt_insts = convert_instance_to_idx_seq(train_tgt_word_insts, tgt_word2idx) valid_tgt_insts = convert_instance_to_idx_seq(valid_tgt_word_insts, tgt_word2idx) #12.好了,现在我们构建一个数据集的字典对象,里面包括了传入的参数、词表以及训练集、验证集。然后用torch.save方法持久化这个字典对象,方便以后调用这个数据集进行训练和测试。至此,源码解析的数据预处理部分就结束了。 data = { 'settings': opt, 'dict': { 'src': src_word2idx, 'tgt': tgt_word2idx}, 'train': { 'src': train_src_insts, 'tgt': train_tgt_insts}, 'valid': { 'src': valid_src_insts, 'tgt': valid_tgt_insts}} print('[Info] Dumping the processed data to pickle file', opt.save_data) torch.save(data, opt.save_data) print('[Info] Finish.') if __name__ == '__main__': main()
写在后面:在自然语言处理任务中,数据清洗是非常重要的一步,因此希望大家十分重视,正所谓垃圾in垃圾out。另外,由于本人水平有限,如果我有什么没说明白或者说错了的地方,非常欢迎大家指出,可以留言,另外本人工作邮箱:1012950361@qq.com 我将以最快的速度更正以及补全,谢谢大家!
敬请关注下一期:Attention is all you need pytorch实现 源码解析02 - 模型的训练 (train.py)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。