
A Detailed PyTorch Implementation of the Transformer (Beginner's Guide)


1. Source Material

Transformer_哔哩哔哩_bilibili: episode 1 of the 4-part video series "68 Transformer【动手学深度学习v2】". https://www.bilibili.com/video/BV1Kq4y1H7FL?p=1&vd_source=f94822d3eca79b8e245bb58bbced6b77

2. Transformer Implementation

2.1 Multi-Head Attention

(Figure source: Transformer_哔哩哔哩_bilibili)

The figure above shows multi-head attention. It differs from ordinary attention in that several attention layers run in parallel, much like stacking multiple convolution kernels in a CNN. Consider the multi-head version of self-attention and let the key/query/value dimension be size_kqv. Each attention head uses an FC layer to project k, q, and v to num_hiddens_single dimensions, so h heads yield an h*num_hiddens_single-dimensional feature, which a final FC layer maps to num_output dimensions. The layers of a single multi-head attention module can then be written as:

  • W_k_i = nn.Linear(size_kqv, num_hiddens_single, bias=False)
  • W_q_i = nn.Linear(size_kqv, num_hiddens_single, bias=False)
  • W_v_i = nn.Linear(size_kqv, num_hiddens_single, bias=False)
  • W_o = nn.Linear(h*num_hiddens_single, num_output)

To keep h attention heads from increasing the computational cost h-fold, one usually sets num_hiddens_single = num_output / h. Writing num_output as num_hiddens and concatenating the h matrices W_k_i, W_q_i, and W_v_i respectively, the concatenated output dimension is h*num_hiddens_single = num_output = num_hiddens. The layers of a single multi-head attention module therefore reduce to four large matrices, which enables parallel computation:

  • W_k = nn.Linear(size_kqv, num_hiddens, bias=False)
  • W_q = nn.Linear(size_kqv, num_hiddens, bias=False)
  • W_v = nn.Linear(size_kqv, num_hiddens, bias=False)
  • W_o = nn.Linear(num_hiddens, num_hiddens)

The "multi-head" structure is thus folded into these large matrix computations. The reshaping that splits the heads apart, and its inverse, are implemented as follows:

import torch
import math
import pandas as pd
from torch import nn
from d2l import torch as d2l

def transpose_qkv(X, num_heads):
    # X.shape = (num_batch, num_qkv, num_hiddens)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # X.shape = (num_batch, num_qkv, num_heads, num_hiddens_single)
    X = X.permute(0, 2, 1, 3)
    # X.shape = (num_batch, num_heads, num_qkv, num_hiddens_single)
    # return shape: (num_batch * num_heads, num_qkv, num_hiddens_single)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    # Inverse of transpose_qkv
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
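
As a quick sanity check (the sizes below are made up), transpose_output inverts transpose_qkv:

# Hypothetical sizes: batch of 2, 4 positions, num_hiddens=8, 2 heads
X = torch.ones(2, 4, 8)
Y = transpose_qkv(X, num_heads=2)
print(Y.shape)                                 # torch.Size([4, 4, 4])
print(transpose_output(Y, num_heads=2).shape)  # torch.Size([2, 4, 8])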

Multi-head attention implementation (using scaled dot-product attention):

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # A single head would use:
        # self.W_q_i = nn.Linear(query_size, num_hiddens // num_heads, bias=bias)
        # Concatenating num_heads copies of W_q_i gives the single large W_q below
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # MultiHeadAttention(queries, keys, values).shape = queries.shape
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Repeat valid_lens num_heads times to match the
            # (num_batch * num_heads, ...) layout produced by transpose_qkv
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
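
A minimal shape check in the spirit of the d2l self-tests (all sizes here are arbitrary) confirms that the output keeps the shape of the queries:

# Hypothetical sizes for a quick self-test of MultiHeadAttention
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, dropout=0.5)
attention.eval()  # disable dropout for a deterministic check
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)  # torch.Size([2, 4, 100])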

2.2 Positional Encoding

Because self-attention lets every token attend to all tokens, all outputs are obtained in one parallel computation. Giving up sequential processing, however, discards order information, so positional information P must be injected into the input X. The entries of P are:

p_{i,2j} = \sin\left( i / 10000^{2j/d} \right)

p_{i,2j+1} = \cos\left( i / 10000^{2j/d} \right)

Implementation:

class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
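
For example (arbitrary sizes), feeding a tensor of zeros exposes the raw positional codes, and the shape is unchanged:

# Hypothetical sizes: encoding dimension 32, sequence length 60, no dropout
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, dropout=0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
print(X.shape)  # torch.Size([1, 60, 32])
# X[0, i, 2j] = sin(i / 10000^(2j/32)), X[0, i, 2j+1] = cos(i / 10000^(2j/32))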

2.3 Position-Wise Feed-Forward Network

In the Transformer, the features output by multi-head attention are further transformed by a position-wise feed-forward network. Its input and output dimensions may differ; conceptually the first two dimensions (batch and position) are fused, the same MLP is applied to every position, and the result is split back apart. In PyTorch this happens automatically because nn.Linear acts only on the last dimension, so the implementation is simply two FC layers:

class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_inputs, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        # input.shape = (num_batch, num_queries, num_attention_hiddens)
        # output.shape = (num_batch, num_queries, ffn_num_outputs)
        return self.dense2(self.relu(self.dense1(X)))
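
Since both FC layers act only on the last dimension, only the feature dimension changes (toy sizes below):

# Hypothetical sizes: 4 input features, 4 hidden units, 8 output features
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
print(ffn(torch.ones((2, 3, 4))).shape)  # torch.Size([2, 3, 8])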

2.4 Encoder Block

(Figure source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)

The figure above shows the block structure used in the encoder. Each block contains one multi-head attention module, two AddNorm modules, and one feed-forward network. AddNorm is a module that wraps the residual connection together with LayerNorm:

class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # output.shape = input.shape
        return self.ln(self.dropout(Y) + X)
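
Because the residual connection requires X and Y to share a shape, the output shape equals the input shape (toy sizes below):

# Hypothetical: LayerNorm over the last two dimensions (3, 4)
add_norm = AddNorm([3, 4], dropout=0.5)
add_norm.eval()
print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)  # torch.Size([2, 3, 4])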

Connecting all of these modules gives the encoder block:

class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size,
                                            value_size, num_hiddens,
                                            num_heads, dropout,
                                            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
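
A shape check with arbitrary sizes (again following the d2l-style self-test) shows that an encoder block leaves the shape of its input unchanged:

# Hypothetical sizes: num_hiddens=24, sequence length 100, 8 heads
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
print(encoder_blk(X, valid_lens).shape)  # torch.Size([2, 100, 24])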

2.5 Assembling the Encoder

(Figure source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)

The figure above shows the overall encoder architecture. The raw tokens pass through an embedding layer to obtain word vectors, which are added to the positional encoding to form the encoder input. Since neither multi-head attention nor AddNorm changes the input/output dimensions, multiple encoder blocks can be stacked, and the output of the last block is the output of the whole encoder. Implementation:

class TransformerEncoder(d2l.Encoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Scale embedding(X) up to a magnitude comparable to the positional encoding
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
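
Stacking two blocks with arbitrary sizes, the encoder maps a batch of token indices to features of shape (batch, steps, num_hiddens):

# Hypothetical: vocabulary of 200 tokens, 2 layers, num_hiddens=24, 8 heads
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
valid_lens = torch.tensor([3, 2])
tokens = torch.ones((2, 100), dtype=torch.long)
print(encoder(tokens, valid_lens).shape)  # torch.Size([2, 100, 24])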

2.6 Decoder Block

(Figure source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)

The figure above shows the structure of a decoder block, which consists of a masked multi-head attention module, a second multi-head attention module, and a feed-forward network. Because the decoder input contains future tokens during training, the attention module that processes this input needs a mask so the model cannot see the future. The second multi-head attention uses the encoder output as key-value pairs and the output of the first attention module as queries to select the relevant information. Implementation:

class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size,
                                             value_size, num_hiddens,
                                             num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size,
                                             value_size, num_hiddens,
                                             num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:  # training
            key_values = X
        else:  # prediction
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Mask out future information
            dec_valid_lens = torch.arange(1, num_steps + 1,
                                          device=X.device).repeat(
                                              batch_size, 1)
        else:  # prediction
            dec_valid_lens = None
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

Here, self.i and state[2] are used to concatenate the output of each time step onto the decoder input during prediction, and dec_valid_lens is an increasing sequence starting at 1 that records how many tokens may be attended to at each step, serving as the mask; a small illustration follows below.
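
For instance, with a hypothetical batch_size=2 and num_steps=4, the training-time mask looks as follows, so the query at step t can only attend to the first t positions:

batch_size, num_steps = 2, 4  # made-up sizes
dec_valid_lens = torch.arange(1, num_steps + 1).repeat(batch_size, 1)
print(dec_valid_lens)
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])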

2.7 Assembling the Decoder

(Figure source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)

The figure above shows the overall decoder, which takes the target input together with the features extracted by the encoder and finally maps the features to token outputs through an FC layer. Implementation:

class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights

3. Model Training

num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

Result:

loss 0.033, 6039.9 tokens/sec on cuda:0

Test:

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

Output:

go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 1.000

The model's attention weights can be accessed through net.encoder.attention_weights and net.decoder.attention_weights:

enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape(
    (num_layers, num_heads, -1, num_steps))
d2l.show_heatmaps(
    enc_attention_weights.cpu(), xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))

Attention weight heatmaps:
