(Image source: Transformer_哔哩哔哩_bilibili)
The figure above illustrates multi-head attention. It differs from ordinary attention in that it runs several attention layers in parallel, much like stacking multiple convolution kernels in a CNN. Consider the multi-head version of self-attention: let the keys, queries, and values each have dimension size_kqv. Every head uses an FC layer to project k, q, and v to num_hiddens_single dimensions, so h heads produce h * num_hiddens_single features, which a final FC layer maps to num_output dimensions. A single multi-head attention layer can then be written, head by head, as

head_i = Attention(q W_q_i, k W_k_i, v W_v_i),  i = 1, ..., h,
output = Concat(head_1, ..., head_h) W_o.
To keep h heads from increasing the computational cost h-fold, one usually sets num_hiddens_single = num_output / h. Writing num_output as num_hiddens and concatenating the h per-head matrices W_q_i, W_k_i, and W_v_i along their output dimension yields projections whose output dimension is h * num_hiddens_single = num_output = num_hiddens. A single multi-head attention layer therefore reduces to four large matrices, W_q, W_k, W_v, and W_o, so all heads can be computed in parallel:
The 'multi-head' part of multi-head attention is thus folded into these large matrix multiplications. The code below implements the reshaping into per-head form and its inverse:
import torch
import math
import pandas as pd
from torch import nn
from d2l import torch as d2l

def transpose_qkv(X, num_heads):
    # X.shape = (num_batch, num_qkv, num_hiddens)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # X.shape = (num_batch, num_qkv, num_heads, num_hiddens_single)
    X = X.permute(0, 2, 1, 3)
    # X.shape = (num_batch, num_heads, num_qkv, num_hiddens_single)
    # return shape: (num_batch * num_heads, num_qkv, num_hiddens_single)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
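
As a quick sanity check (a minimal sketch, assuming the two functions above are defined and torch is imported; the shapes are illustrative), the two transforms are inverses of each other:

X = torch.randn(2, 4, 100)            # (num_batch, num_qkv, num_hiddens)
num_heads = 5
Xh = transpose_qkv(X, num_heads)      # (2 * 5, 4, 20)
Xr = transpose_output(Xh, num_heads)  # back to (2, 4, 100)
print(Xh.shape, Xr.shape)
print(torch.equal(X, Xr))             # True: transpose_output undoes transpose_qkv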

Implementation of multi-head attention (using scaled dot-product attention):
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # Per head: self.W_q_i = nn.Linear(query_size, num_hiddens // num_heads, bias=bias);
        # concatenating num_heads such W_q_i gives self.W_q (likewise for W_k and W_v)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # output.shape = (num_batch, num_queries, num_hiddens)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # Every head shares the same valid lengths
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
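
A shape check mirroring the d2l book's usage (the hyperparameter values are illustrative):

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)   # torch.Size([2, 4, 100])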

In self-attention every token attends to all tokens, so all outputs can be computed in one parallel pass. Giving up sequential processing, however, discards order information, so positional information has to be injected into the input by adding a positional encoding matrix P to the embeddings. Its elements are

p_{i, 2j} = sin( i / 10000^{2j / num_hiddens} ),   p_{i, 2j+1} = cos( i / 10000^{2j / num_hiddens} ),

where i is the position index and 2j, 2j+1 index the feature dimensions. Implementation:
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # P holds the encodings for up to max_len positions
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
                0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        # Add the encodings for the first X.shape[1] positions
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
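
As an illustrative check (not part of the original post; it assumes the PositionalEncoding class above and d2l are available), the encoding depends only on position and feature index and can be inspected directly:

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
print(X.shape)                      # torch.Size([1, 60, 32])
# Lower feature indices oscillate faster along the position axis than higher ones
P = pos_encoding.P[:, :num_steps, :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])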
In the Transformer, the features produced by multi-head attention are further transformed by a position-wise feed-forward network whose input and output dimensions may differ. Because nn.Linear acts only on the last dimension, the first two dimensions (batch and sequence) are effectively merged into one big batch and then split apart again, so the same two FC layers are applied independently at every position:
class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_inputs, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        # input.shape = (num_batch, num_queries, num_attention_hiddens)
        # output.shape = (num_batch, num_queries, ffn_num_outputs)
        return self.dense2(self.relu(self.dense1(X)))
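
A minimal sketch of its position-wise behavior, assuming the class above is defined (the sizes are illustrative):

ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
out = ffn(torch.ones((2, 3, 4)))
print(out.shape)                               # torch.Size([2, 3, 8]): only the last dim changes
print(torch.allclose(out[0, 0], out[0, 1]))    # True: identical positions map to identical outputs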
(Image source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)
The figure above shows the block structure used in the encoder. Each block contains one multi-head attention layer, two AddNorm layers, and one feed-forward network. AddNorm is a module that wraps a residual connection together with LayerNorm:
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # output.shape = input.shape
        return self.ln(self.dropout(Y) + X)
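
A quick check, with illustrative shapes, that AddNorm leaves the input shape unchanged:

add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)  # torch.Size([2, 3, 4])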
Connecting all of these modules gives the encoder block:
class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size,
                                            value_size, num_hiddens,
                                            num_heads, dropout,
                                            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
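
A shape check for a single encoder block (illustrative hyperparameters; valid_lens is made up), confirming that the output shape equals the input shape so blocks can be stacked:

X = torch.ones((2, 100, 24))             # (batch, steps, num_hiddens)
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
print(encoder_blk(X, valid_lens).shape)  # torch.Size([2, 100, 24])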

(Image source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)
The figure above shows the overall encoder architecture. The raw tokens pass through an embedding layer to obtain word vectors, which are added to the positional encoding to form the encoder input. Since neither multi-head attention nor AddNorm changes the input and output dimensions, multiple encoder blocks can be stacked, and the output of the last block is the output of the whole encoder. Implementation:
class TransformerEncoder(d2l.Encoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Scale the embedding by sqrt(num_hiddens) so its magnitude is
        # comparable to that of the positional encoding
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
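
A quick shape check of the full encoder, again with illustrative hyperparameters:

encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
tokens = torch.ones((2, 100), dtype=torch.long)   # (batch, steps) of token ids
valid_lens = torch.tensor([3, 2])
print(encoder(tokens, valid_lens).shape)          # torch.Size([2, 100, 24])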

(Image source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)
The figure above shows the structure of a decoder block, which consists of masked multi-head attention, encoder-decoder multi-head attention, and a feed-forward network. Because the decoder input during training contains future tokens, a mask has to be added to the first attention module so that no position can see information from later positions. The second multi-head attention uses the features extracted by the encoder as keys and values and the output of the first attention as queries, selecting the relevant encoder information. Implementation:
class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size,
                                             value_size, num_hiddens,
                                             num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size,
                                             value_size, num_hiddens,
                                             num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:  # training
            key_values = X
        else:  # prediction: append this step's input to the cached keys/values
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Mask future positions: position t may only attend to positions 1..t
            dec_valid_lens = torch.arange(1, num_steps + 1,
                                          device=X.device).repeat(
                                              batch_size, 1)
        else:  # prediction
            dec_valid_lens = None

        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

Here, self.i and state[2] cache, for each block, the representations of the tokens decoded so far, so that at prediction time each new step's input is concatenated onto this cache. dec_valid_lens is an increasing sequence starting from 1 that gives the number of tokens each position is allowed to attend to, serving as the causal mask.
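A sketch of how a single decoder block is driven (illustrative values; the encoder output is replaced by a dummy tensor of the right shape), together with what dec_valid_lens looks like for a small case:

X = torch.ones((2, 100, 24))            # decoder input (batch, steps, num_hiddens)
valid_lens = torch.tensor([3, 2])
enc_outputs = torch.ones((2, 100, 24))  # dummy stand-in for a real encoder output
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
state = [enc_outputs, valid_lens, [None]]
print(decoder_blk(X, state)[0].shape)   # torch.Size([2, 100, 24])

# The causal mask built in training mode, for batch_size=2, num_steps=4:
print(torch.arange(1, 5).repeat(2, 1))
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])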
(Image source: 10.7. Transformer — 动手学深度学习 2.0.0 documentation)
The figure above shows the overall decoder architecture. It receives the target input and the features extracted by the encoder, and a final FC layer turns the features into word predictions. Implementation:
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # [0]: decoder self-attention weights, [1]: encoder-decoder attention weights
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights

Training:
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

Result:
loss 0.033, 6039.9 tokens/sec on cuda:0
Test:
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
Output:
go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 1.000
The model's attention weights can be retrieved through net.encoder.attention_weights and net.decoder.attention_weights. For the encoder:
enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape(
    (num_layers, num_heads, -1, num_steps))

d2l.show_heatmaps(
    enc_attention_weights.cpu(), xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))
Attention weight heatmaps:
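The decoder's weights can be visualized in the same way. The sketch below follows the d2l book's post-processing of dec_attention_weight_seq returned by d2l.predict_seq2seq in the test loop above (it assumes that loop has run, so translation and dec_attention_weight_seq refer to the last test sentence):

dec_attention_weights_2d = [head[0].tolist()
                            for step in dec_attention_weight_seq
                            for attn in step for blk in attn for head in blk]
# Pad with zeros: during prediction the key length of the self-attention grows step by step
dec_attention_weights_filled = torch.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape(
    (-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.permute(1, 2, 3, 0, 4)

d2l.show_heatmaps(
    dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
    xlabel='Key positions', ylabel='Query positions',
    titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))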