当前位置:   article > 正文

【transformer】自注意力源码解读和复杂度计算

【transformer】自注意力源码解读和复杂度计算

Self-attention

1

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。这里的注意力机制是通过计算查询向量 Q Q Q和键向量 K K K之间的相似性,来为值向量 V V V分配不同的权重。如果两个向量越相似,则它们之间的权重应该越大,反之则越小。

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)  # 获取文本嵌入维度大小
    # 按照注意力机制的公式计算注意力分数
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    # 是否使用掩码
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # 使用softmax对最后一个维度获得注意力张量
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # 注意力张量与value相乘得到query的注意力表示
    return torch.matmul(p_attn, value), p_attn
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

一个形状为 N × M N\times M N×M 的矩阵,与另一个形状为 M × P M\times P M×P的矩阵相乘,其运算复杂度来源于乘法操作的次数,时间复杂度为 O ( N M P ) O(NMP) O(NMP)

Self-attention的公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。
Self-Attention的计算复杂度主要来自三个方面:查询矩阵、键矩阵和值矩阵的乘积、softmax 的计算、以及输出向量和值的加权平均。
对于一个由n个单词组成的输入序列,假设有d个维度的特征,那么查询矩阵、键矩阵和值矩阵的维度都将是 n × d。

  • 对于查询矩阵 Q 和键矩阵 K 的点积, n × d n\times d n×d d × n d\times n d×n计算复杂度是 O ( n 2 d ) O(n^2d) O(n2d)
  • 每行 softmax 的计算,计算复杂度为 O ( n ) O(n) O(n),对n行做softmax,复杂度为 O ( n 2 ) O(n^2) O(n2)
  • 对于值矩阵 V (维度 n × d n\times d n×d)和 softmax 后的结果(维度 n × n n\times n n×n)进行点积,得到每个查询向量的加权平均值,复杂度是 O ( n 2 d ) O(n^2d) O(n2d)

因此,总的计算复杂度是 O ( n 2 d ) + O ( n 2 ) + O ( n 2 d ) ≃ O ( n 2 d ) O(n^2d) + O(n^2) + O(n^2d) \simeq O(n^2d) O(n2d)+O(n2)+O(n2d)O(n2d)
由于这个复杂度是关于输入序列长度n的平方级别,因此Self-Attention在处理长序列时可能会面临计算上的挑战。

多头注意力

2
多头注意力的计算公式如下:
MultiHead ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , … ,  head  h ) W O  where   head  i = A ( Q W i Q , K W i K , V W i V ) MultiHead(Q,K,V)=Concat(head1,, head h)WO where  head i=A(QWQi,KWKi,VWVi)

MultiHead(Q,K,V) where  head i=Concat(head1,, head h)WO=A(QWiQ,KWiK,VWiV)其中, Q , K , V Q,K,V Q,K,V 分别表示查询、键和值, h h h 表示头数, h e a d i head_i headi 表示第 i i i 个注意力头, W O W^O WO 表示输出层的权重矩阵。

# 用于深度拷贝的copy工具包
import copy

# 首先需要定义克隆函数, 因为在多头注意力机制的实现中, 用到多个结构相同的线性层.
# 我们将使用clone函数将他们一同初始化在一个网络层列表对象中. 之后的结构中也会用到该函数.
def clones(module, N):
    """用于生成相同网络层的克隆函数, 它的参数module表示要克隆的目标网络层, N代表需要克隆的数量"""
    # 在函数中, 我们通过for循环对module进行N次深度拷贝, 使其每个module成为独立的层,
    # 然后将其放在nn.ModuleList类型的列表中存放.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# 我们使用一个类来实现多头注意力机制的处理
class MultiHeadedAttention(nn.Module):
    def __init__(self, head, embedding_dim, dropout=0.1):
        """在类的初始化时, 会传入三个参数,head代表头数,embedding_dim代表词嵌入的维度, 
           dropout代表进行dropout操作时置0比率,默认是0.1."""
        super(MultiHeadedAttention, self).__init__()

        # 在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除,
        # 这是因为我们之后要给每个头分配等量的词特征.也就是embedding_dim/head个.
        assert embedding_dim % head == 0

        # 得到每个头获得的分割词向量维度d_k
        self.d_k = embedding_dim // head

        # 传入头数h
        self.head = head

        # 然后获得线性层对象,通过nn的Linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim,然后使用clones函数克隆四个,
        # 为什么是四个呢,这是因为在多头注意力中,Q,K,V各需要一个,最后拼接的矩阵还需要一个,因此一共是四个.
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)

        # self.attn为None,它代表最后得到的注意力张量,现在还没有结果所以为None.
        self.attn = None

        # 最后就是一个self.dropout对象,它通过nn中的Dropout实例化而来,置0比率为传进来的参数dropout.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """前向逻辑函数, 它的输入参数有四个,前三个就是注意力机制需要的Q, K, V,
           最后一个是注意力机制中可能需要的mask掩码张量,默认是None. """

        # 如果存在掩码张量mask
        if mask is not None:
            # 使用unsqueeze拓展维度
            mask = mask.unsqueeze(0)

        # 接着,我们获得一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本.
        batch_size = query.size(0)

        # 之后就进入多头处理环节
        # 首先利用zip将输入QKV与三个线性层组到一起,然后使用for循环,将输入QKV分别传到线性层中,
        # 做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结果进行维度重塑,多加了一个维度h,代表头数,
        # 这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度,
        # 计算机会根据这种变换自动计算这里的值.然后对第二维和第三维进行转置操作,
        # 为了让代表句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系,
        # 从attention函数中可以看到,利用的是原始输入的倒数第一和第二维.这样我们就得到了每个头的输入.
        query, key, value = \
           [model(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2)
            for model, x in zip(self.linears, (query, key, value))]

        # 得到每个头的输入后,接下来就是将他们传入到attention中,
        # 这里直接调用我们之前实现的attention函数.同时也将mask和dropout传入其中.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 通过多头注意力计算后,我们就得到了每个头计算结果组成的4维张量,我们需要将其转换为输入的形状以方便后续的计算,
        # 因此这里开始进行第一步处理环节的逆操作,先对第二和第三维进行转置,然后使用contiguous方法,
        # 这个方法的作用就是能够让转置后的张量应用view方法,否则将无法直接使用,
        # 所以,下一步就是使用view重塑形状,变成和输入形状相同.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.d_k)

        # 最后使用线性层列表中的最后一个线性层对输入进行线性变换得到最终的多头注意力结构的输出.
        return self.linears[-1](x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73

在多头注意力中,假设有 h h h 个头,每个头的查询、键和值的维度是 d k d_k dk d k d_k dk d v d_v dv,一般情况 d q = d k = d v = d h d_q=d_k=d_v=\frac{d}{h} dq=dk=dv=hd, 输入序列的长度为 N N N

  • 输入线性映射的复杂度: n × d n\times d n×d d × d h d \times \frac{d}{h} d×hd,计算复杂度是 O ( n d 2 h ) O(\frac{nd^2 }{h}) O(hnd2)
  • 注意力计算:输入线性映射后的维度 n × d h n \times \frac{d}{h} n×hd n × d h n \times \frac{d}{h} n×hd d h × n \frac{d}{h}\times n hd×n计算复杂度是 O ( n 2 d h ) O(n^2\frac{d}{h}) O(n2hd)
  • 输出线性映射: 多个头的结果concat成一个 n × d n\times d n×d矩阵, n × d n\times d n×d d × d d \times d d×d,计算复杂度是 O ( n d 2 ) O(nd^2) O(nd2)

总时间复杂度 O ( n d 2 h + n 2 d h + n d 2 ) O(\frac{nd^2}{h}+n^2\frac{d}{h}+nd^2) O(hnd2+n2hd+nd2)


参考:
传智博客-Transformer

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/368403
推荐阅读
相关标签
  

闽ICP备14008679号