当前位置:   article > 正文








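seq2seq任务(如翻译、对话等)输入一个序列、输出一个序列,常用RNN实现:单向RNN在每个位置只能看到当前位置及之前的内容,并据此得到当前结果;双向RNN则可以看遍整个序列,再得到每一个位置的结果。但是RNN必须按时间步依次计算,不利于并行化;CNN虽然可以高度并行,但受限于卷积核的感受野,单层无法看遍整个序列,所以提出了self-attention layer。常见的attention function分为additive attention和dot-product attention,后者即 $\mathrm{softmax}(QK^T)V$(Transformer中实际使用的是缩放版本 $\mathrm{softmax}\left(QK^T/\sqrt{d_k}\right)V$,其中 $d_k$ 为key的维度,用来避免点积过大、使softmax进入梯度极小的区域)。它可以直接用矩阵乘法并行计算,所以更优。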


如图所示是Transformer的结构,由一个encoder(左)和一个decoder(右)组成。encoder和decoder内部都包含attention layer,除此之外还有FFN(feed forward network,即两层线性变换中间夹一个激活函数)。







怎么将Transformer用在CV里面呢?就是把图像也看作序列:将 $H \times W \times F$ 的图像(或特征图)reshape为 $HW \times F$,即可看作 $HW$ 个 $F$ 维向量。对于高分辨率图像,$HW$ 很大,而self-attention的计算量随序列长度平方增长,因此先通过切patch的操作将原图切成若干较小的图像块(patch),再把这些图像块送入attention层进行计算。

4、Position Encoding

由于self-attention本身不包含位置信息(position information),所以要加上位置编码PE(Positional Encoding,维度与embedding的维度一样;既可以是固定的正弦/余弦编码,也可以通过学习得到),来区分query的不同位置,或者说对相对位置进行表达。NLP中这个向量能决定当前词的位置,或者说在一个句子中不同的词之间的距离;CV中就是确定某个像素点所在的行列位置。



  1. multihead_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout)
  2. out, attention_map = multihead_attn(query, key, value, attn_mask, key_padding_mask)



默认情况下q、k、v的embedding维度需要一样,即q的维度为(L, B, embed_dim),k、v的维度为(S, B, embed_dim)(这里假设batch_first=False;若batch_first=True,则B在第一维)。否则应当传入特殊参数kdim和vdim,此时k的维度为(S, B, kdim),v的维度为(S, B, vdim)。


  1. class MultiheadAttention(Module):
  2. def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
  3. kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
  4. factory_kwargs = {'device': device, 'dtype': dtype}
  5. super(MultiheadAttention, self).__init__()
  6. self.embed_dim = embed_dim
  7. self.kdim = kdim if kdim is not None else embed_dim
  8. self.vdim = vdim if vdim is not None else embed_dim
  9. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  10. self.num_heads = num_heads
  11. self.dropout = dropout
  12. self.batch_first = batch_first
  13. self.head_dim = embed_dim // num_heads
  14. assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
  15. # embed_dim必须能被num_heads整除
  16. # 以下proj_weight记作A
  17. if self._qkv_same_embed_dim is False:
  18. self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
  19. self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
  20. self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
  21. self.register_parameter('in_proj_weight', None)
  22. else:
  23. self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
  24. self.register_parameter('q_proj_weight', None)
  25. self.register_parameter('k_proj_weight', None)
  26. self.register_parameter('v_proj_weight', None)
  27. # 以下in_proj_bias记作b
  28. if bias:
  29. self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
  30. else:
  31. self.register_parameter('in_proj_bias', None)
  32. # forward过程第一步按公式y = xA^T + b对输入的q,k,v进行线性变换


  1. self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
  2. if add_bias_kv:
  3. self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  4. self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  5. else:
  6. self.bias_k = self.bias_v = None
  7. self.add_zero_attn = add_zero_attn
  8. self._reset_parameters()
  9. # 让attention要训的参数好好初始化
  10. def _reset_parameters(self):
  11. if self._qkv_same_embed_dim:
  12. xavier_uniform_(self.in_proj_weight)
  13. else:
  14. xavier_uniform_(self.q_proj_weight)
  15. xavier_uniform_(self.k_proj_weight)
  16. xavier_uniform_(self.v_proj_weight)
  17. if self.in_proj_bias is not None:
  18. constant_(self.in_proj_bias, 0.)
  19. constant_(self.out_proj.bias, 0.)
  20. if self.bias_k is not None:
  21. xavier_normal_(self.bias_k)
  22. if self.bias_v is not None:
  23. xavier_normal_(self.bias_v)
  24. def __setstate__(self, state):
  25. # Support loading old MultiheadAttention checkpoints generated by v1.1.0
  26. if '_qkv_same_embed_dim' not in state:
  27. state['_qkv_same_embed_dim'] = True
  28. super(MultiheadAttention, self).__setstate__(state)



  1. def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
  2. need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
  3. r"""
  4. Args:
  5. # 如果__init__时batch_first=True,则q、k、v的B在第一个维度,否则如下所示:
  6. query: (L, B, E_q)
  7. key: (S, B, E_k)
  8. value: (S, B, E_v)
  9. key_padding_mask: If specified, a mask of (B, S) indicating which elements within key to ignore for the purpose of attention. Binary and byte masks are supported.
  10. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
  11. need_weights: Default: True.
  12. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (B\cdot\text{num\_heads}, L, S).
  13. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
  14. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.
  15. Outputs:
  16. - **attn_output** - Attention outputs (L, B, E)或(B, L, E)
  17. - **attn_output_weights** - Attention output weights (B, L, S) when need_weights=True
  18. """


  1. if self.batch_first:
  2. query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
  3. # 保证q,k,v的B一直在第二个维度
  4. # 将前面初始化好的权重偏差以及Q/K/V、mask传入更底层的函数,得到输出的最终结果和中间结果
  5. if not self._qkv_same_embed_dim:
  6. # 如果要用embedding维度不一样的k/v,需要令F.multi_head_attention_forward
  7. # 的输入参数use_separate_proj_weight=True
  8. attn_output, attn_output_weights = F.multi_head_attention_forward(
  9. query, key, value, self.embed_dim, self.num_heads,
  10. self.in_proj_weight, self.in_proj_bias,
  11. self.bias_k, self.bias_v, self.add_zero_attn,
  12. self.dropout, self.out_proj.weight, self.out_proj.bias,
  13. training=self.training,
  14. key_padding_mask=key_padding_mask, need_weights=need_weights,
  15. attn_mask=attn_mask, use_separate_proj_weight=True,
  16. q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
  17. v_proj_weight=self.v_proj_weight)
  18. else:
  19. attn_output, attn_output_weights = F.multi_head_attention_forward(
  20. query, key, value, self.embed_dim, self.num_heads,
  21. self.in_proj_weight, self.in_proj_bias,
  22. self.bias_k, self.bias_v, self.add_zero_attn,
  23. self.dropout, self.out_proj.weight, self.out_proj.bias,
  24. training=self.training,
  25. key_padding_mask=key_padding_mask, need_weights=need_weights,
  26. attn_mask=attn_mask)
  27. if self.batch_first:
  28. return attn_output.transpose(1, 0), attn_output_weights
  29. else:
  30. return attn_output, attn_output_weights



  1. def multi_head_attention_forward(  # functional core that MultiheadAttention.forward delegates to
  2. query: Tensor,  # (L, B, E): target length, batch size, embedding size
  3. key: Tensor,  # (S, B, E_k): source length S
  4. value: Tensor,  # (S, B, E_v)
  5. embed_dim_to_check: int,  # expected E; asserted against query's last dim below
  6. num_heads: int,
  7. in_proj_weight: Tensor,  # packed q/k/v projection weight; not used when use_separate_proj_weight=True
  8. in_proj_bias: Optional[Tensor],  # packed q/k/v bias, split into 3 chunks
  9. bias_k: Optional[Tensor],  # optional extra key row appended along the sequence dim (section 4)
  10. bias_v: Optional[Tensor],  # optional extra value row; must be given together with bias_k
  11. add_zero_attn: bool,  # append an all-zero key/value row (section 6)
  12. dropout_p: float,  # dropout on the attention weights; forced to 0.0 when training=False
  13. out_proj_weight: Tensor,  # final output projection
  14. out_proj_bias: Optional[Tensor],
  15. training: bool = True,
  16. key_padding_mask: Optional[Tensor] = None,  # (B, S); True / non-zero marks keys to ignore
  17. need_weights: bool = True,  # also return the head-averaged attention weights
  18. attn_mask: Optional[Tensor] = None,  # (L, S) or (B*num_heads, L, S)
  19. use_separate_proj_weight: bool = False,  # use q/k/v_proj_weight instead of in_proj_weight
  20. q_proj_weight: Optional[Tensor] = None,
  21. k_proj_weight: Optional[Tensor] = None,
  22. v_proj_weight: Optional[Tensor] = None,
  23. static_k: Optional[Tensor] = None,  # precomputed keys (B*num_heads, S', head_dim); replace the projected k
  24. static_v: Optional[Tensor] = None,  # precomputed values (B*num_heads, S', head_dim); replace the projected v
  25. ) -> Tuple[Tensor, Optional[Tensor]]:
  26. tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)  # main tensor args checked for overrides
  27. # __torch_function__ override hook (PyTorch's override protocol): if any of these arguments is a tensor-like object / Tensor subclass that overrides __torch_function__, the whole call is handed to that override. For plain tensors this branch is skipped, so it can be ignored when reading the algorithm.
  28. if has_torch_function(tens_ops):
  29. return handle_torch_function(
  30. multi_head_attention_forward,
  31. tens_ops,
  32. query,
  33. key,
  34. value,
  35. embed_dim_to_check,
  36. num_heads,
  37. in_proj_weight,
  38. in_proj_bias,
  39. bias_k,
  40. bias_v,
  41. add_zero_attn,
  42. dropout_p,
  43. out_proj_weight,
  44. out_proj_bias,
  45. training=training,
  46. key_padding_mask=key_padding_mask,
  47. need_weights=need_weights,
  48. attn_mask=attn_mask,
  49. use_separate_proj_weight=use_separate_proj_weight,
  50. q_proj_weight=q_proj_weight,
  51. k_proj_weight=k_proj_weight,
  52. v_proj_weight=v_proj_weight,
  53. static_k=static_k,
  54. static_v=static_v,
  55. )
  56. # set up shape vars
  57. tgt_len, bsz, embed_dim = query.shape  # query is (L, B, E): tgt_len = L, bsz = B
  58. src_len, _, _ = key.shape  # key is (S, B, E_k): src_len = S
  59. assert embed_dim == embed_dim_to_check, \
  60. f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
  61. if isinstance(embed_dim, torch.Tensor):
  62. # embed_dim can be a tensor when JIT tracing
  63. head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
  64. else:
  65. head_dim = embed_dim // num_heads
  66. assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
  67. # embed_dim must be divisible by num_heads (checked again here, as in MultiheadAttention.__init__); otherwise the assert above fails.


  1. # use_separate_proj_weight=True: k/v may come with embedding sizes different from q (kdim / vdim)
  2. if use_separate_proj_weight:
  3. # allow MHA to have different embedding dimensions when separate projection weights are used
  4. assert key.shape[:2] == value.shape[:2], \
  5. f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
  6. # k and v must agree on sequence length and batch size, i.e. both start with (S, B)
  7. else:
  8. assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"  # packed weight: k and v must have identical shapes
  9. # compute the in-projection y = x A^T + b for q, k, v
  10. if not use_separate_proj_weight:
  11. q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)  # slices the packed weight/bias
  12. else:
  13. assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
  14. assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
  15. assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
  16. if in_proj_bias is None:
  17. b_q = b_k = b_v = None
  18. else:
  19. b_q, b_k, b_v = in_proj_bias.chunk(3)  # the bias stays packed even with separate weights
  20. q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

(2.1)_in_projection_packed和 _in_projection函数


  1. def _in_projection_packed(
  2. q: Tensor,
  3. k: Tensor,
  4. v: Tensor,
  5. w: Tensor,
  6. b: Optional[Tensor] = None,
  7. ) -> List[Tensor]:
  8. E = q.size(-1)
  9. if k is v:
  10. if q is k:
  11. # q=k=v,做的是self-attention,可以直接将(3E,E)的权重矩阵w与(L,B,E)的q送入linear,
  12. # linear的做法是q*w^T+b,所以得到结果(L,B,3E)。
  13. # 再用chunk在最后一维均分3块,得到3个(L,B,E)大小的q、k、v.
  14. return linear(q, w, b).chunk(3, dim=-1)
  15. else:
  16. # k=v, encoder-decoder attention,则k、v的linear变换可合并,q单独做
  17. w_q, w_kv = w.split([E, E * 2])
  18. if b is None:
  19. b_q = b_kv = None
  20. else:
  21. b_q, b_kv = b.split([E, E * 2])
  22. return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
  23. else:
  24. # q/k/v各不同,则先将w分3块,再分别做linear
  25. w_q, w_k, w_v = w.chunk(3)
  26. if b is None:
  27. b_q = b_k = b_v = None
  28. else:
  29. b_q, b_k, b_v = b.chunk(3)
  30. return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
  31. def _in_projection(
  32. q: Tensor,
  33. k: Tensor,
  34. v: Tensor,
  35. w_q: Tensor,
  36. w_k: Tensor,
  37. w_v: Tensor,
  38. b_q: Optional[Tensor] = None,
  39. b_k: Optional[Tensor] = None,
  40. b_v: Optional[Tensor] = None,
  41. ) -> Tuple[Tensor, Tensor, Tensor]:
  42. # embedding维度上q,k,v不同,权重矩阵单独存入,检验输入输出大小是否符合后,分别做linear
  43. Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
  44. assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
  45. assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
  46. assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
  47. assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
  48. assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
  49. assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
  50. return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


  1. # prep attention mask: normalize its dtype to bool/float and its shape to 3D
  2. if attn_mask is not None:
  3. if attn_mask.dtype == torch.uint8:  # legacy byte mask: converted to bool
  4. warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
  5. attn_mask = attn_mask.to(torch.bool)
  6. else:
  7. assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
  8. f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
  9. # ensure attn_mask's dim is 3
  10. if attn_mask.dim() == 2:
  11. correct_2d_size = (tgt_len, src_len)  # (L, S)
  12. if attn_mask.shape != correct_2d_size:
  13. raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
  14. attn_mask = attn_mask.unsqueeze(0)  # (L, S) -> (1, L, S): one mask shared by every batch entry and head
  15. elif attn_mask.dim() == 3:
  16. correct_3d_size = (bsz * num_heads, tgt_len, src_len)  # one (L, S) mask per (batch entry, head) pair
  17. if attn_mask.shape != correct_3d_size:
  18. raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
  19. else:
  20. raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
  21. # prep key padding mask
  22. if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
  23. warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
  24. key_padding_mask = key_padding_mask.to(torch.bool)  # byte -> bool; other dtypes are left as-is here

(4)k,v concat bias_k,bias_v

  1. # add bias along batch dimension (currently second)
  2. if bias_k is not None and bias_v is not None:
  3. # bias_k and bias_v are created in nn.MultiheadAttention as (1, 1, E) parameters
  4. assert static_k is None, "bias cannot be added to static key."
  5. assert static_v is None, "bias cannot be added to static value."
  6. k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) # (S, B, E) -> (S+1, B, E): bias_k repeated over the batch and appended as one extra key position
  7. v = torch.cat([v, bias_v.repeat(1, bsz, 1)])  # same for v: (S+1, B, E)
  8. if attn_mask is not None: # attn_mask is already 3D here: (1, L, S) or (B*num_heads, L, S)
  9. attn_mask = pad(attn_mask, (0, 1))
  10. # pad the last dim with 0 elements on the left and 1 on the right, filled with pad's default value 0 (False for bool masks), so the new key position is never masked
  11. # shape becomes (1, L, S+1) or (B*num_heads, L, S+1)
  12. if key_padding_mask is not None:
  13. key_padding_mask = pad(key_padding_mask, (0, 1))  # (B, S) -> (B, S+1), the new column is not padded
  14. else:
  15. assert bias_k is None  # bias_k and bias_v must be given together
  16. assert bias_v is None


  1. # reshape q, k, v for multihead attention and make em batch first
  2. q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
  3. # (L, B, E) -> (L, B*H, E/H) -> (B*H, L, E/H), with H = num_heads and E/H = head_dim
  4. if static_k is None:
  5. k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
  6. # (B*H, S, E/H), or (B*H, S+1, E/H) if bias_k was appended above
  7. else:
  8. # TODO finish disentangling control flow so we don't do in-projections when statics are passed
  9. assert static_k.size(0) == bsz * num_heads, \
  10. f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
  11. assert static_k.size(2) == head_dim, \
  12. f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
  13. k = static_k  # caller-supplied keys replace the projected ones
  14. if static_v is None:
  15. v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
  16. # same as k: (B*H, S, E/H) or (B*H, S+1, E/H)
  17. else:
  18. # TODO finish disentangling control flow so we don't do in-projections when statics are passed
  19. assert static_v.size(0) == bsz * num_heads, \
  20. f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
  21. assert static_v.size(2) == head_dim, \
  22. f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
  23. v = static_v  # caller-supplied values replace the projected ones

(6)add zero attention


  1. # add zero attention along batch dimension (now first)
  2. if add_zero_attn:
  3. zero_attn_shape = (bsz * num_heads, 1, head_dim)
  4. k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
  5. # k: (B*H, S', E/H) -> (B*H, S'+1, E/H), where S' is S, or S+1 if bias_k was appended earlier
  6. v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)  # same for v
  7. if attn_mask is not None:
  8. attn_mask = pad(attn_mask, (0, 1))  # keep the mask width in sync with the extra key position
  9. if key_padding_mask is not None:
  10. key_padding_mask = pad(key_padding_mask, (0, 1))
  11. # update source sequence length after adjustments
  12. src_len = k.size(1)
  13. # S, S+1 or S+2 depending on add_bias_kv / add_zero_attn; with both False (the nn.MultiheadAttention defaults) it is still S.
  14. # merge key padding and attention masks
  15. if key_padding_mask is not None:
  16. assert key_padding_mask.shape == (bsz, src_len), \
  17. f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
  18. key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
  19. expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
  20. # key_padding_mask is now (B*H, 1, S): repeated for every head and broadcast over L; attn_mask is (1, L, S) or (B*H, L, S)
  21. if attn_mask is None:
  22. attn_mask = key_padding_mask
  23. elif attn_mask.dtype == torch.bool:
  24. attn_mask = attn_mask.logical_or(key_padding_mask)
  25. # positions masked by key_padding_mask become masked in attn_mask too (union of both masks)
  26. else:
  27. attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))  # float mask: padded keys get -inf
  28. # convert mask to float
  29. if attn_mask is not None and attn_mask.dtype == torch.bool:
  30. new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
  31. new_attn_mask.masked_fill_(attn_mask, float("-inf"))  # True -> -inf, False -> 0.0: an additive mask
  32. attn_mask = new_attn_mask
  33. # adjust dropout probability: dropout is only applied in training; at inference it is disabled.
  34. if not training:
  35. dropout_p = 0.0


  1. attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)  # (B*H, L, E/H) output, (B*H, L, S) weights
  2. attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
  3. # (B*H, L, E/H) -> (L, B*H, E/H) -> (L, B, E): the heads are concatenated back into E
  4. attn_output = linear(attn_output, out_proj_weight, out_proj_bias)  # output projection, still (L, B, E)
  5. if need_weights:
  6. # average attention weights over heads
  7. attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
  8. # (B*H, L, S) -> (B, H, L, S); the sum over dim 1 divided by num_heads below gives (B, L, S)
  9. return attn_output, attn_output_weights.sum(dim=1) / num_heads
  10. else:
  11. return attn_output, None


  1. def _scaled_dot_product_attention(
  2. q: Tensor,
  3. k: Tensor,
  4. v: Tensor,
  5. attn_mask: Optional[Tensor] = None,
  6. dropout_p: float = 0.0,
  7. ) -> Tuple[Tensor, Tensor]:
  8. B, Nt, E = q.shape
  9. q = q / math.sqrt(E)
  10. # (B*H, L, E/H) x (B*H, E/H, S) -> attn(B*H, L, S)
  11. attn = torch.bmm(q, k.transpose(-2, -1))
  12. if attn_mask is not None:
  13. attn += attn_mask # attn(B*H, L, S)
  14. # 在attention score上加attn_mask,mask的部分加负无穷大的数,经softmax后为0
  15. attn = softmax(attn, dim=-1) # attn(B*H, L, S)
  16. # 在最后一个维度上做softmax
  17. if dropout_p > 0.0:
  18. attn = dropout(attn, p=dropout_p)
  19. # (B*H, L, S) x (B*H, S, E/H) -> (B*H, L, E/H)
  20. output = torch.bmm(attn, v)
  21. return output, attn

