
Notes on attention (theory + code)


This is a study-log post; if you spot mistakes, please point them out in the comments~

Attaching Hung-yi Lee's lecture on the Transformer:

Hung-yi Lee's 2020 Machine Learning & Deep Learning course (complete version, Mandarin), on Bilibili

Key conceptual questions

1. Background on attention (not that important):

In seq2seq tasks (translation, dialogue, etc.) a sequence goes in and a sequence comes out. They were commonly implemented with RNNs: a unidirectional RNN only sees the sequence up to the current position when producing the current output, while a bidirectional RNN sees the whole sequence for every output. But RNNs are hard to parallelize, and CNNs, which do parallelize well, cannot see the whole sequence, so the self-attention layer was proposed. Common attention functions fall into additive attention and dot-product attention; the latter is softmax(QK^T / sqrt(d_k)) V (scaled dot-product attention), which maps onto matrix multiplications and parallelizes well, so it is preferred.
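
To make the formula concrete, here is a minimal sketch of scaled dot-product attention; the tensor sizes and the function name are made up for illustration.

import math
import torch

def dot_product_attention(q, k, v):
    # q: (L, d_k), k: (S, d_k), v: (S, d_v) -- toy single-head case, no batch dimension
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (L, S)
    weights = scores.softmax(dim=-1)                          # each row sums to 1
    return weights @ v                                        # (L, d_v)

q = torch.randn(4, 8)    # 4 query positions, dim 8
k = torch.randn(6, 8)    # 6 key positions
v = torch.randn(6, 16)
out = dot_product_attention(q, k, v)  # (4, 16)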

2. I used to assume, without really thinking about it, that "transformer" and "attention" were the same thing... they are not.

As the figure shows, the Transformer consists of an encoder (left) and a decoder (right). The encoder and decoder each contain attention layers, plus an FFN (feed-forward network, i.e. linear layers + an activation).

If you split the encoder and decoder apart, stacking multiple encoders gives you BERT, and stacking multiple decoders gives you GPT.

3. Why does the Transformer, which originated in NLP, also work so well in CV?

Well, this is the question a teacher stumped me with at my project defense... The plain answer I wanted to give was: transformers work really well in NLP, so people tried them in CV and it turned out they work really well there too... but I was afraid of getting scolded and ended up mumbling nothing.

Why is the Transformer taking off in CV? (CSDN blog of the 机器学习与生成对抗网络 public account)

Perhaps that blogger's article can serve as a reference for answering this question.

In my advisor's view, in today's CV the attention layer should be treated as basic knowledge, just like the convolution layer, and used as a standard building block. Convolution is better at extracting local features, while attention strengthens global interactions and enlarges the receptive field. In Hung-yi Lee's words, a CNN is a simplified self-attention: convolution only attends to a hand-crafted receptive field, whereas self-attention first finds the relevant pixels, effectively learning the shape and size of the receptive field automatically.

How do we use a transformer in CV? Treat the image as a sequence: reshape an H x W x F image to HW x F, i.e. view it as HW vectors of dimension F. For high-resolution images, split the image into patches and feed the smaller patch-level blocks into the attention layer.
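
A minimal sketch of that reshape, assuming a (B, C, H, W) feature map and an 8x8 patch size (both just illustrative choices):

import torch

B, F, H, W = 2, 64, 32, 32
feat = torch.randn(B, F, H, W)             # a CNN feature map in (B, C, H, W) layout

# View the H x W grid as a sequence of HW tokens, each an F-dim vector
tokens = feat.flatten(2).permute(2, 0, 1)  # (HW, B, F): sequence-first, as nn.MultiheadAttention expects

# For a high-resolution image, group pixels into patches first (here 8x8 patches)
img = torch.randn(B, 3, 224, 224)
p = 8
patches = img.unfold(2, p, p).unfold(3, p, p)                           # (B, 3, 28, 28, 8, 8)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, 3 * p * p)   # (B, 784, 192): one token per patch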

4. Position Encoding

Since self-attention has no position information, a learned PE (with the same dimension as the embedding) is added to distinguish the different positions of the queries, or in other words to express relative position. In NLP this vector encodes where the current word is, or the distance between different words in a sentence; in CV it encodes which row and column a pixel sits in.
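
A minimal sketch of a learned positional encoding added to sequence-first embeddings; the class name and sizes are illustrative, not from any particular library.

import torch
import torch.nn as nn

class LearnedPE(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        # one learnable vector per position, same dim as the embedding
        self.pe = nn.Parameter(torch.zeros(max_len, 1, d_model))
        nn.init.normal_(self.pe, std=0.02)

    def forward(self, x):              # x: (L, B, d_model), sequence-first
        return x + self.pe[: x.size(0)]

x = torch.randn(10, 2, 64)             # 10 tokens, batch 2, embedding dim 64
x = LearnedPE(max_len=100, d_model=64)(x)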

Code walkthrough

(I am using CUDA 10.2, torch 1.10.0 and Python 3.8.12; the calling code may differ across versions.)

  multihead_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
  out, attention_map = multihead_attn(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
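
For reference, a runnable usage sketch with made-up sizes; the boolean masks are all-False here, meaning nothing is masked.

import torch
import torch.nn as nn

d_model, nhead, dropout = 64, 8, 0.1
L, S, B = 10, 12, 2                          # target length, source length, batch size

multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

query = torch.randn(L, B, d_model)           # (L, B, E) -- sequence-first by default
key = torch.randn(S, B, d_model)
value = torch.randn(S, B, d_model)
attn_mask = torch.zeros(L, S, dtype=torch.bool)          # True = position not allowed to attend
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)   # True = key position is padding

out, attention_map = multihead_attn(query, key, value,
                                    attn_mask=attn_mask,
                                    key_padding_mask=key_padding_mask)
print(out.shape, attention_map.shape)        # torch.Size([10, 2, 64]) torch.Size([2, 10, 12])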

1. torch.nn.MultiheadAttention lives in torch\nn\modules\activation.py

(1) In the __init__ function, the required arguments are embed_dim and num_heads.

By default q, k and v must have the same embedding dimension, i.e. q and k have shapes (L, B, embed_dim) and (S, B, embed_dim) respectively; otherwise you must pass the extra arguments kdim and vdim, so that k has shape (S, B, kdim) (and likewise v has shape (S, B, vdim)).

Whether or not q, k, v start out with the same embedding dimension, they are all projected by a linear transform to the same embed_dim; the only difference is that the projection weights are stored separately in the unequal case.
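
A small sketch of the unequal-dimension case (all sizes are illustrative):

import torch
import torch.nn as nn

embed_dim, kdim, vdim, nhead = 64, 32, 48, 8
attn = nn.MultiheadAttention(embed_dim, nhead, kdim=kdim, vdim=vdim)

q = torch.randn(10, 2, embed_dim)   # (L, B, embed_dim)
k = torch.randn(12, 2, kdim)        # (S, B, kdim)
v = torch.randn(12, 2, vdim)        # (S, B, vdim)
out, w = attn(q, k, v)
print(out.shape)                    # torch.Size([10, 2, 64]) -- always projected back to embed_dim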

class MultiheadAttention(Module):
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # embed_dim must be divisible by num_heads
        # the *_proj_weight parameters below are the A in y = xA^T + b
        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)
        # in_proj_bias below is the b in y = xA^T + b
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        # the first step of forward applies y = xA^T + b to linearly project the input q, k, v

(2) The NonDynamicallyQuantizableLinear class exists only to avoid triggering a non-obvious error when scripting an improperly quantized attention layer. out_proj here initializes the weight and bias of an embed_dim-to-embed_dim linear layer, which applies one final linear transform to the attention result before it is returned (a quick parameter-shape check follows the listing below).

        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    # give the parameters that the attention layer will train a proper initialization
    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super(MultiheadAttention, self).__setstate__(state)
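
A quick check of the parameters this __init__ creates, for both the packed and the separate-weights case (sizes are illustrative):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)
print(mha.in_proj_weight.shape)   # torch.Size([192, 64])  -- 3*embed_dim x embed_dim, packed q/k/v
print(mha.in_proj_bias.shape)     # torch.Size([192])
print(mha.out_proj.weight.shape)  # torch.Size([64, 64])   -- the NonDynamicallyQuantizableLinear out_proj

mha2 = nn.MultiheadAttention(embed_dim=64, num_heads=8, kdim=32, vdim=48)
print(mha2.q_proj_weight.shape, mha2.k_proj_weight.shape, mha2.v_proj_weight.shape)
# torch.Size([64, 64]) torch.Size([64, 32]) torch.Size([64, 48]) -- separate weights when dims differ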

(3) The forward function

Notes on the arguments and the outputs.

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            # if batch_first=True was passed to __init__, B is the first dimension of q/k/v; otherwise:
            query: (L, B, E_q)
            key: (S, B, E_k)
            value: (S, B, E_v)
            key_padding_mask: If specified, a mask of shape (B, S) indicating which elements within key
                to ignore for the purpose of attention. Binary and byte masks are supported.
                For a binary mask, a True value indicates that the corresponding key value will be ignored
                for the purpose of attention. For a byte mask, a non-zero value indicates that the
                corresponding key value will be ignored.
            need_weights: Default: True.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions.
                Must be of shape (L, S) or (B * num_heads, L, S). A 2D mask is broadcast across the batch,
                while a 3D mask allows a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a True value indicates
                that the corresponding position is not allowed to attend. For a byte mask, a non-zero
                value indicates that the corresponding position is not allowed to attend. For a float
                mask, the mask values are added to the attention weight.
        Outputs:
            - **attn_output** - Attention outputs, (L, B, E) or (B, L, E)
            - **attn_output_weights** - Attention output weights, (B, L, S), returned when need_weights=True
        """

The actual computation is delegated to F.multi_head_attention_forward.

        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
            # make sure B stays in the second dimension of q, k, v
        # pass the initialized weights/biases plus q/k/v and the masks to the lower-level function,
        # which returns the final output and the intermediate attention weights
        if not self._qkv_same_embed_dim:
            # k/v with a different embedding dimension requires use_separate_proj_weight=True
            # when calling F.multi_head_attention_forward
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

2. F.multi_head_attention_forward is implemented in torch\nn\functional.py

(1) The first part is mostly bookkeeping, nothing essential

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    # I don't fully understand this part; skipping it does not affect the rest
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
        )

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    # embed_dim is again required to be divisible by num_heads, otherwise an error is raised

(2) in_projection: apply a linear transform to q, k and v; whatever their original dimensions were, after the projection they all share the same embedding dimension.

    # use_separate_proj_weight=True means k/v came in with a different embedding dimension
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
        # the sequence and batch dims of k and v must match, i.e. both start with (S, B)
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    # compute in-projection
    if not use_separate_proj_weight:
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

(2.1) The _in_projection_packed and _in_projection functions

See the comments: the former linearly projects q/k/v when they share the same embedding dimension, while the latter handles the case where the dimensions differ. A small equivalence sketch follows the listing.

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    E = q.size(-1)
    if k is v:
        if q is k:
            # q = k = v: self-attention. The (3E, E) weight matrix w and the (L, B, E) q go straight
            # into linear, which computes q @ w^T + b, giving a (L, B, 3E) result.
            # chunk then splits the last dim into 3 equal parts: q, k, v, each of shape (L, B, E).
            return linear(q, w, b).chunk(3, dim=-1)
        else:
            # k = v: encoder-decoder attention; the linear transforms of k and v can be merged,
            # while q is projected separately
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        # q/k/v all differ: split w into 3 chunks first, then apply linear to each
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


def _in_projection(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor] = None,
    b_k: Optional[Tensor] = None,
    b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    # q, k, v differ in embedding dimension and the weight matrices are stored separately;
    # after checking the shapes, apply linear to each
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
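
A small sketch (my own, with made-up sizes) showing that the packed path is just three ordinary linear projections whose weights live in one (3E, E) matrix:

import torch
import torch.nn.functional as F

E, L, B = 8, 4, 2
x = torch.randn(L, B, E)                    # self-attention case: q = k = v = x
w = torch.randn(3 * E, E)                   # packed in_proj_weight
b = torch.randn(3 * E)                      # packed in_proj_bias

# packed: one big linear, then split the last dim into q/k/v
q1, k1, v1 = F.linear(x, w, b).chunk(3, dim=-1)

# equivalent: split the weights first, then do a separate linear for q
w_q, w_k, w_v = w.chunk(3)
b_q, b_k, b_v = b.chunk(3)
q2 = F.linear(x, w_q, b_q)

print(torch.allclose(q1, q2))               # True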

(3) Check that the mask inputs meet the dtype and shape requirements

    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

(4) Concatenate bias_k and bias_v onto k and v

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        # bias_k and bias_v were initialized in nn.MultiheadAttention as parameters of shape (1, 1, E)
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])  # (S+1, B, E)
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:  # attn_mask is (L, S) or (?, L, S)
            attn_mask = pad(attn_mask, (0, 1))
            # pad pads the last dim of the mask: nothing on the left, one extra element on the right,
            # filled with 0 by default; the shape becomes (L, S+1)
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

(5) Reshape q, k, v according to the number of heads (a shape-only sketch follows the listing)

    # reshape q, k, v for multihead attention and make em batch first
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    # (B*H, L, E/H)
    if static_k is None:
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        # (B*H, S, E/H) or (B*H, S+1, E/H)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        # same as k
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v
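
A shape-only sketch of the head-splitting reshape above, with toy sizes:

import torch

L, B, E, H = 5, 2, 16, 4                     # tgt_len, batch, embed_dim, num_heads
head_dim = E // H
q = torch.randn(L, B, E)

q = q.contiguous().view(L, B * H, head_dim).transpose(0, 1)
print(q.shape)                               # torch.Size([8, 5, 4]), i.e. (B*H, L, E/H)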

(6) Add zero attention

And do some extra processing on the masks; a tiny sketch of the final bool-to-float conversion follows the listing.

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        # k: (B*H, S, E/H) -> (B*H, S+1, E/H), or one more again if bias_k was concatenated earlier
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)
    # S, S+1 or S+2 depending on the flags; with the defaults add_bias_kv=add_zero_attn=False it is still S

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        # key_padding_mask: (B*H, 1, S); attn_mask: (B*H, L, S)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
            # positions masked by key_padding_mask are also masked in attn_mask
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # convert mask to float
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    # adjust dropout probability: dropout is only applied during training, not at inference
    if not training:
        dropout_p = 0.0
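
A tiny sketch of the final bool-to-float conversion, showing what actually gets added to the attention scores:

import torch

attn_mask = torch.tensor([[False, True],
                          [False, False]])             # True = not allowed to attend

new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
print(new_attn_mask)   # [[0., -inf], [0., 0.]] -- added to the scores before softmax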

(7) Do the attention itself with _scaled_dot_product_attention

    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    # (L, B, E)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        # attn_output_weights: (B*H, L, S) -> (B, H, L, S) -> (B, L, S)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None

(7.1) _scaled_dot_product_attention

def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B*H, L, E/H) x (B*H, E/H, S) -> attn (B*H, L, S)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask  # attn (B*H, L, S)
        # add attn_mask to the attention scores; masked positions get -inf added, which becomes 0 after softmax
    attn = softmax(attn, dim=-1)  # attn (B*H, L, S)
    # softmax over the last dimension
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # (B*H, L, S) x (B*H, S, E/H) -> (B*H, L, E/H)
    output = torch.bmm(attn, v)
    return output, attn
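
To tie the walkthrough together, here is a small end-to-end check of my own (not part of the PyTorch source): it redoes the packed in-projection, head split, scaled dot-product attention and output projection by hand, and compares against nn.MultiheadAttention in eval mode.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
E, H, L, B = 16, 4, 5, 2
mha = nn.MultiheadAttention(E, H, dropout=0.0).eval()

x = torch.randn(L, B, E)                       # self-attention: q = k = v = x
with torch.no_grad():
    ref, _ = mha(x, x, x)

    # 1) packed in-projection
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    # 2) split into heads: (L, B, E) -> (B*H, L, E/H)
    hd = E // H
    q, k, v = [t.contiguous().view(L, B * H, hd).transpose(0, 1) for t in (q, k, v)]
    # 3) scaled dot-product attention
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(hd), dim=-1)
    out = attn @ v
    # 4) merge heads and apply the output projection
    out = out.transpose(0, 1).contiguous().view(L, B, E)
    out = F.linear(out, mha.out_proj.weight, mha.out_proj.bias)

print(torch.allclose(ref, out, atol=1e-6))     # True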
