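This post follows the previous four Transformer walkthrough posts and supplements the fourth one with a closer look at how the MultiheadAttention class processes data in the source code.
Transformer Implementation and PyTorch Source Code Walkthrough (4): The Encoder Layer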
\site-packages\torch\nn\modules\activation.py
\site-packages\torch\nn\functional.py
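Class used from \site-packages\torch\nn\modules\activation.py:
MultiheadAttention class
Functions used from \site-packages\torch\nn\functional.py:
_in_projection_packed function
_scaled_dot_product_attention function
multi_head_attention_forward function
Step 1:
Initialize the MultiheadAttention class with the parameters below. The main arguments are the embedding dimension (embed_dim) and the number of heads (num_heads).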
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
             kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
    factory_kwargs = {'device': device, 'dtype': dtype}
    super(MultiheadAttention, self).__init__()
    self.embed_dim = embed_dim
    self.kdim = kdim if kdim is not None else embed_dim
    self.vdim = vdim if vdim is not None else embed_dim
    # Boolean flag: whether k, q, and v all have the same dimension
    self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
    # print("=========================_qkv_same_embed_dim=:",self._qkv_same_embed_dim)

    self.num_heads = num_heads
    self.dropout = dropout
    self.batch_first = batch_first
    self.head_dim = embed_dim // num_heads
    assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

    if self._qkv_same_embed_dim is False:
        self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
        self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
        self.register_parameter('in_proj_weight', None)
    else:
        self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
        # print(self.in_proj_weight.shape)
        self.register_parameter('q_proj_weight', None)
        self.register_parameter('k_proj_weight', None)
        self.register_parameter('v_proj_weight', None)

    if bias:
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
    else:
        self.register_parameter('in_proj_bias', None)
    self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    if add_bias_kv:
        self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
    else:
        self.bias_k = self.bias_v = None

    self.add_zero_attn = add_zero_attn

    self._reset_parameters()
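Here in_proj_weight and in_proj_bias are the weight and bias of the input projection. The flag _qkv_same_embed_dim records whether kdim and vdim both equal embed_dim, which is always the case for self-attention. When it is true, the q, k, and v projection weights are packed into a single in_proj_weight of shape (3 * embed_dim, embed_dim), three times the size of one projection; otherwise three separate weights (q_proj_weight, k_proj_weight, v_proj_weight) are created.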
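As a quick check of the constructor logic, the sketch below builds two modules and inspects which projection parameters get allocated. The sizes (embed_dim=8, num_heads=2, kdim=vdim=6) are arbitrary values chosen for illustration and do not come from the original post.

```python
import torch
from torch import nn

# Arbitrary illustrative sizes (assumptions, not from the post).
embed_dim, num_heads = 8, 2

# Default case: kdim == vdim == embed_dim, so one packed weight is created.
packed = nn.MultiheadAttention(embed_dim, num_heads)
packed_shape = tuple(packed.in_proj_weight.shape)  # (3 * embed_dim, embed_dim)
bias_shape = tuple(packed.in_proj_bias.shape)      # (3 * embed_dim,)
head_dim = packed.head_dim                         # embed_dim // num_heads

# Different key/value dims: three separate weights, and no packed weight.
separate = nn.MultiheadAttention(embed_dim, num_heads, kdim=6, vdim=6)
k_shape = tuple(separate.k_proj_weight.shape)      # (embed_dim, kdim)
has_packed = separate.in_proj_weight is not None

print(packed_shape, bias_shape, head_dim, k_shape, has_packed)
```

Note that in_proj_bias stays packed with length 3 * embed_dim in both cases; only the weights are split.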
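Step 2:
Initialize the weights and biases:
Initialization happens in _reset_parameters(). The projection weights are filled by xavier_uniform_, which draws every entry uniformly at random from the range -a to a. The in_proj_bias and out_proj bias are set to 0, and bias_k and bias_v, when present, are initialized with xavier_normal_. The bound a is computed with the following formula: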
$$a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}$$
The formula above is implemented as follows:
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)
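The code first computes a standard deviation and then multiplies it by the square root of 3, which can look different from the formula in the docstring. A short check that the two agree, using the fact that a uniform distribution on (-a, a) has variance a²/3:

```latex
\begin{aligned}
\text{std} &= \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} \\
a &= \sqrt{3} \times \text{std}
   = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} \\
\operatorname{Var}\big[\mathcal{U}(-a, a)\big] &= \frac{a^{2}}{3} = \text{std}^{2}
\end{aligned}
```

So sampling uniformly from (-a, a) yields entries whose standard deviation is exactly the std computed in the code.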
fan_in and fan_out are determined from the dimensions of the input tensor:
def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out
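To make the numbers concrete, here is a minimal pure-Python sketch that mirrors the fan computation on a plain shape tuple and derives the bound a for a packed in_proj_weight. The helper names fan_in_and_fan_out and xavier_uniform_bound are hypothetical, and embed_dim = 512 is an assumed value.

```python
import math


def fan_in_and_fan_out(shape):
    """Hypothetical helper: same rule as above, applied to a shape tuple."""
    if len(shape) < 2:
        raise ValueError("need at least 2 dimensions")
    receptive_field_size = 1
    for s in shape[2:]:
        receptive_field_size *= s
    fan_in = shape[1] * receptive_field_size   # input feature maps
    fan_out = shape[0] * receptive_field_size  # output feature maps
    return fan_in, fan_out


def xavier_uniform_bound(shape, gain=1.0):
    """Hypothetical helper: the bound a from the Xavier uniform formula."""
    fan_in, fan_out = fan_in_and_fan_out(shape)
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


embed_dim = 512  # assumed value for illustration
weight_shape = (3 * embed_dim, embed_dim)  # shape of the packed in_proj_weight
fan_in, fan_out = fan_in_and_fan_out(weight_shape)
bound = xavier_uniform_bound(weight_shape)
print(fan_in, fan_out, bound)
```

Because the packed weight has 3 * embed_dim rows, its fan_out is three times larger than that of a single projection, so the bound is smaller than it would be if q, k, and v were initialized as separate (embed_dim, embed_dim) matrices.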
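Step 3:
The forward method of the MultiheadAttention class performs the computation for each batch.
In simplified form, it carries out the following operations:
(1) Each of the three inputs passes through a fully connected layer
if not use_separate_proj_weight:
    # Multiply the three inputs by in_proj_weight.
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)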
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # print("=========q:",q.shape)
            # print("=========w:",w.shape)
            return linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
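The self-attention branch (q is k is v) relies on one large matrix multiplication followed by chunk giving the same result as three separate projections. A minimal sketch that checks this equivalence with the public torch.nn.functional.linear; the tensor sizes and the random weights are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
E = 4                      # embedding dimension (illustrative)
x = torch.randn(5, 2, E)   # (seq_len, batch, E); self-attention uses x as q, k and v
w = torch.randn(3 * E, E)  # packed weight, q/k/v order along dim 0
b = torch.randn(3 * E)     # packed bias, same order

# Packed path: one matmul, then split the last dimension into three pieces.
q, k, v = F.linear(x, w, b).chunk(3, dim=-1)

# Separate path: split the weights first, then run three matmuls.
w_q, w_k, w_v = w.chunk(3)
b_q, b_k, b_v = b.chunk(3)
q2 = F.linear(x, w_q, b_q)
k2 = F.linear(x, w_k, b_k)
v2 = F.linear(x, w_v, b_v)

same = all(
    torch.allclose(a, c, atol=1e-5) for a, c in ((q, q2), (k, k2), (v, v2))
)
print(q.shape, same)
```

Each of q, k, and v keeps the input shape (seq_len, batch, E), which matches the Output note in the docstring.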
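(2) The three tensors are reshaped so that the heads are folded into the batch-size dimension
This is where the multi-head mechanism actually takes effect.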
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
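The reshape is easier to follow on a small tensor. The sketch below applies the same view and transpose to q and checks that row b * num_heads + h of the result holds head h of batch element b, meaning the slice h * head_dim : (h + 1) * head_dim of the embedding. All sizes here are assumed values for illustration.

```python
import torch

# Illustrative sizes (assumptions).
tgt_len, bsz, num_heads, head_dim = 3, 2, 4, 5
embed_dim = num_heads * head_dim

# Distinct values make it easy to see where each element ends up.
q = torch.arange(tgt_len * bsz * embed_dim, dtype=torch.float32)
q = q.view(tgt_len, bsz, embed_dim)

# Same operation as in the source: (tgt_len, bsz, embed_dim)
#   -> (tgt_len, bsz * num_heads, head_dim) -> (bsz * num_heads, tgt_len, head_dim)
q_heads = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

# Pick one batch element and one head, and compare against a direct slice.
b_idx, h_idx = 1, 2
expected = q[:, b_idx, h_idx * head_dim:(h_idx + 1) * head_dim]
matches = torch.equal(q_heads[b_idx * num_heads + h_idx], expected)
print(q_heads.shape, matches)
```

After this step every head looks like an independent batch entry, which is why a single torch.bmm call can process all heads at once.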
(3) The three tensors are multiplied in sequence
This produces the attention weights and the final output. Note the use of torch.bmm here.
def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.

        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
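torch.bmm multiplies matching batch entries independently, so the function is equivalent to running ordinary single-sequence attention once per batch entry (here each entry is one head of one sample). A minimal sketch of that equivalence, leaving out the mask and dropout; the tensor sizes and random inputs are assumptions for illustration.

```python
import math
import torch

torch.manual_seed(0)
B, Nt, Ns, E = 2, 3, 4, 8   # illustrative sizes; B plays the role of bsz * num_heads
q = torch.randn(B, Nt, E)
k = torch.randn(B, Ns, E)
v = torch.randn(B, Ns, E)

# Batched version, following the same steps as the source (no mask, no dropout).
scores = torch.bmm(q / math.sqrt(E), k.transpose(-2, -1))  # (B, Nt, Ns)
attn = torch.softmax(scores, dim=-1)
output = torch.bmm(attn, v)                                # (B, Nt, E)

# The same computation, one batch entry at a time with plain 2D matmuls.
looped = torch.stack([
    torch.softmax((q[i] / math.sqrt(E)) @ k[i].T, dim=-1) @ v[i]
    for i in range(B)
])

row_sums = attn.sum(dim=-1)  # each row of attention weights should sum to 1
agree = torch.allclose(output, looped, atol=1e-5)
print(output.shape, attn.shape, agree)
```

In the real function E is the per-head dimension head_dim, because the heads have already been folded into the batch dimension, so the scaling factor is the square root of head_dim rather than of embed_dim.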
The source code's handling of num_heads contains some redundant code.