赞
踩
0. Encoder.onnx 和decoder.onnx
Seq2seq结构也称为encoder-decoder结构,在decoder结构为单步解码时,seq2seq的导出的与只有encoder时(如BERT等)没有差别。当decoder为多步解码时,如生成式任务中的机器翻译、文本摘要等,由于decoder需要多次用到,问题变得稍稍复杂了一点。
由于onnx对多步解码的支持较差,我们选择将模型拆分成encoder和decoder两部分。(官网上有torch.jit.script(model),但对transformers之外的框架不一定适用,所以我们讨论更一般的情况。)
在常见的基于transformer机制的框架中,一般都是先定义各层中的计算机制,如linear,activation (relu, gelu), scaled_multi_head_attention等,再定义各层(layer, T5中的block)的计算,然后再定义encoder, decoder,最后定义整个transformer及对应的train和inferece各状态对应的输入输出。
在定义encoder和decoder的时候,分别加入了src和tgt embedding来进行初始化,有时使用shared embedding, 即src embedding和tgt embedding相同。Encoder和decoder的最大不同在于,encoder中只涉及self-attention的计算,decoder先进行self-attention的计算,再进行cross_attention的计算。
我们只考虑一层的情况,多层情况为一层情况的重复。
在进行inference的时候,为了加速计算,decoder self-attention部分往往把之前计算好的keys和values保存下来,称为past_keys 和past_values。并且每次只输入当前步的token(only 1 token),对应的embedding作为query,计算出当前步对应的key和value, 并分别将其cat到past_keys和past_values中。之后通过self-attention得到当前步的的输出,作为之后进行cross-attention的query。
进行cross attention时,query来自之前的self-attention结果,key和value是由encoder的输出变换得到的。
1. 主要框架 THUMT和transformer的结构
Thumt和transformer等均提供了tensorflow和torch的代码实现。接下来,我们首先以常见框架的torch版本为例,来说明如何生成encoder和decoder的onnx模型。github上有一个onnxt5和fastt5可供参考。需要注意的是,onnxt5实现是没有使用past_key_values, fastt5导出了3个模型,encoder.onnx, decoder-init.onnx和decoder.onnx,并且没有提供beam-search的代码,但对我们理解模型结构,参考输入输出的格式,依然提供了非常好的借鉴。
Fastt5中,decoder-init.onnx只是为了把past_keys, 和past_values创造出来,其实我们完全可以改写代码,在encoder的模型中就加进past_keys, 和past_values,即在encoder模型中把它们创造出来,设置他们的长度为0。在这里我们说的past_keys和past_values是指decoder的self-attention部分用到的past_key_values,为了描述方便,cross attention的past_keys和past_values我们先不考虑。由于past_key_channels和past_values_channels可能不同,所有我们将past_key_channels拆分成past_keys和past_values.
到此,我们就得到了encoder对应的onnx模型,
输入为:features_source, features_mask
输出为:encoder_hidden_states, past_keys, past_values
其中features_source是输入的一个batch的token ids, features_mask是根据一个batch内各句的长度生成的mask,features_mask也可以用source_lengths代替(thumt-tf)。
接下来decoder onnx模型的导出也是类似的,值得注意的是,由于使用了past_key_values, 所有每次输入的序列长度为1,并且在解码时,past_keys和past_values每解码一个token,长度增加1。这里decoder onnx模型的输入和输出分别为:
Decoder input:
features_source, features_mask, encoder_hidden_states, past_keys, past_values, curr_seqs
Decoder_output:
log_prob, present_keys, present_values
其中,如果features_mask信息已经存在时,features_source可以去掉,但为了和原来模型的输入保持一致,并且考虑到内存占用其实并不大,所以带着影响也不大。
Curr_seqs为已经解码出来的序列。
Log_prob的形状为(batc_size, vocab_size),用于在词表空间进行beam_search.
Present_keys为past_keys cat当前步的key得到的,所以经过一步解码后,present_keys比past_keys的长度大1. Present_values与present_keys情况类似。
在得到encoder和decoder之后,就可以通过重写beam-search来进行解码操作。(beam search这里我们就先不详细讨论了)。
如果想将decoder cross_attention的past_key和past_value添加进来,情况和decoder self_attention是类似的,值得注意的是,decoder cross_attention 的past_key_values经过第一步解码生成之后,后面就不会再变化。因此后续的解码步骤不需要再对其进行更新。
2. tf2onnx和torch2onnx
以上主要讨论了torch的模型转为onnx,tf的模型转换为onnx是类似的,只是要先转换为pb模型,再由pb模型通过工具(pb2onnx)转换为onnx。
tf模型转换为pb也是根据需要的输入输出对之前定义的模型结构进行调整,并分别导出encoder和decoder部分的pb。
最后,我们说一下可能遇到的问题,或者说印象深刻的问题,torch转onnx的时候,可能会遇到某些算符在onnx中没有,需要重写,例如triu(三角矩阵),另外需要注意,tensor的seq_length如果长度不同,则dynamic_axis需要用不同的名字(很重要,否则会卡很久)。
tf的pb转onnx时,也可能遇到某些slice不支持,如[:, :, -1:]等,可以考虑先取[:, :, -1], 再expand_dim。这个表面上说起来容易,实际上只能靠推理得知[:, :, -1:]不work(使用的opset=11)。另外一个问题是,tf对onnx的支持也不太好,我自己实践中是在原有的tf框架里面用torch重新实现了beam-search。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。