当前位置:   article > 正文

【深度学习】多头注意力机制详解

多头注意力机制

一句话总结:
多头注意力机制中的多头不同于卷积神经网络中的多个卷积层中的卷积核,卷积神经网络中的多个卷积层相当于将单个卷积网络复制了num_layers次,每一个卷积层都可以独立进行运算。而多头注意力则可理解为将输入的特征值拆分成更加细碎的小块,对每一小块赋值一个单独的可训练权重参数,然后共用同一个隐藏层输出结果,每个头并不能看作是一个完整独立的编解码架构而单独运算

初学,不知道自己理解得对不对,有误请指出,欢迎讨论。

#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self,key_size,query_size,value_size,num_hiddens,num_heads,dropout,bias=False,**kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads=num_heads
        self.attention=d2l.DotProductAttention(dropout)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=bias)
        self.W_v=nn.Linear(value_size,num_hiddens,bias=bias)
        self.W_o=nn.Linear(num_hiddens,num_hiddens,bias=bias)

    def forward(self,queries,keys,values,valid_lens):
        # queries,keys,values形状:(batch_size,查询或者键值对的个数,num_hiddens)
        # valid_lens的形状:(batch_size,)或(batch_size,查询个数)
        # 经过变换后,输出的queries,keys,values的形状:
        # (batch_size*num_heads,查询或者键值对的个数,num_hiddens/num_heads)
        queries=transpose_qkv(self.W_q(queries),self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

这段代码定义了一个多头注意力(MultiHeadAttention)的类,实现了多头注意力机制的前向传播过程。

逐行解释如下:

  1. class MultiHeadAttention(nn.Module):

    • 定义了一个名为MultiHeadAttention的类,继承自nn.Module
  2. def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):

    • 初始化函数,用于创建MultiHeadAttention对象。
    • 参数:
      • key_size:键的大小。
      • query_size:查询的大小。
      • value_size:值的大小。
      • num_hiddens:隐藏单元的数量。
      • num_heads:注意力头的数量。
      • dropout:用于丢弃注意力权重的dropout概率。
      • bias:是否在线性变换中使用偏置项。
      • **kwargs:其他关键字参数。
    • 在初始化函数中,创建了多头注意力所需的各个线性变换层。
  3. def forward(self, queries, keys, values, valid_lens):

    • 前向传播函数,定义了多头注意力的计算逻辑。
    • 参数:
      • queries:查询的张量,形状为(batch_size, 查询或"键-值"对的个数, num_hiddens)。
      • keys:键的张量,形状为(batch_size, 查询或"键-值"对的个数, num_hiddens)。
      • values:值的张量,形状为(batch_size, 查询或"键-值"对的个数, num_hiddens)。
      • valid_lens:有效长度的张量,形状为(batch_size,) 或 (batch_size, 查询的个数)。
    • 在前向传播函数中,执行以下步骤:
  4. queries = transpose_qkv(self.W_q(queries), self.num_heads)

    • 对查询张量进行线性变换,并进行形状变换,以便多头注意力的计算。
    • self.W_q(queries)对查询张量进行线性变换。
    • transpose_qkv函数对变换后的张量进行形状变换,以满足多头注意力的计算要求。
  5. keys = transpose_qkv(self.W_k(keys), self.num_heads)

    • 对键张量进行线性变换,并进行形状变换。
  6. values = transpose_qkv(self.W_v(values), self.num_heads)

    • 对值张量进行线性变换,并进行形状变换。
  7. if valid_lens is not None:

    • 检查有效长度是否存在。
  8. valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

    • 如果有效长度存在,则根据注意力头的数量进行重复插值,以匹配多头注意力计算后的形状。
  9. output = self.attention(queries, keys, values, valid_lens)

    • 调用self.attention进行多头注意力的计算

,得到输出张量。

  1. output_concat = transpose_output(output, self.num_heads)
  • 对多头注意力的输出张量进行形状变换,以恢复到原始形状。
  1. return self.W_o(output_concat)
  • 对形状变换后的张量再进行线性变换,得到最终的多头注意力的输出。

总体来说,这段代码定义了一个多头注意力类,其中通过线性变换和形状变换对查询、键和值进行处理,并使用DotProductAttention计算多头注意力的输出。最终,通过线性变换将多头注意力的输出再次进行变换,得到最终的多头注意力输出结果。

Q:这里的多头体现在哪里呢?

A:多头注意力体现在以下几个方面:

  1. self.num_heads:在初始化函数中传入的num_heads参数指定了多头注意力的头数。每个注意力头都有自己的查询、键和值的线性变换参数。

  2. queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads):这三行代码将查询、键和值分别通过线性变换后,通过transpose_qkv函数进行形状变换。这个形状变换将原始的查询、键和值张量分别重塑为(batch_size * num_heads, 查询或"键-值"对的个数, num_hiddens / num_heads)的形状,其中num_hiddens / num_heads表示每个注意力头的隐藏单元数量。

  3. output_concat = transpose_output(output, self.num_heads):在前向传播的最后,通过transpose_output函数对多头注意力的输出进行形状变换。这个形状变换将多头注意力的输出张量重新重塑为(batch_size, 查询的个数, num_hiddens)的形状。

通过以上操作,输入的查询、键和值张量被分为多个注意力头,每个注意力头进行独立的线性变换和注意力计算。最后,多个注意力头的输出通过形状变换进行合并,得到多头注意力的最终输出结果。每个注意力头可以关注不同的特征和上下文,从而捕捉到更丰富的信息,增强了模型的表示能力和表达能力。

Q:num_hiddens / num_heads表示每个注意力头的隐藏单元数量 就是说,这里的多头是通过batch_size * num_heads扩展了原始张量的第0维度,但同时将隐藏层总数均摊给了heads,实际上,原始张量的总元素=batch_size查询或者键值对的个数num_hiddens,多头变换后的总元素=batch_sizenum_heads查询或者键值对的个数*num_hiddens/num_heads,两者是不变的?

A:您的理解是正确的。

在多头注意力中,通过将原始张量的第0维度(通常是batch_size)扩展为batch_size * num_heads,实现了多个注意力头的并行计算。同时,为了保持总的隐藏单元数量不变,将隐藏单元数量num_hiddens均摊给了每个注意力头,即每个注意力头的隐藏单元数量为num_hiddens / num_heads

您提到的总元素数量的计算也是正确的。原始张量的总元素数量为batch_size * 查询或者键值对的个数 * num_hiddens,而多头变换后的总元素数量为batch_size * num_heads * 查询或者键值对的个数 * num_hiddens / num_heads。这两者是相等的,因为在多头变换过程中,将总的隐藏单元数量保持不变,仅仅是将其重新分配给了多个注意力头,以实现并行计算。

Q:为什么总隐藏单元数量要保持不变?

A:总的隐藏单元数量在多头注意力中保持不变的原因是为了确保模型的表达能力和表示能力不受影响。

在多头注意力中,每个注意力头都有自己的线性变换参数,其中包括查询、键和值的线性变换矩阵。这些线性变换矩阵的维度是num_hiddens x (num_hiddens/num_heads),其中num_hiddens是原始隐藏单元的数量,num_hiddens/num_heads是每个注意力头的隐藏单元数量。

通过保持总的隐藏单元数量不变,即num_hiddens,我们可以确保在多头注意力中每个注意力头都能够使用相同的总信息量。尽管每个注意力头的隐藏单元数量减少了,但通过多个注意力头的并行计算和信息的组合,整个模型仍然可以捕捉到足够丰富的特征和上下文信息。

如果我们在多头注意力中改变了总的隐藏单元数量,比如将其增加或减少,那么模型的表达能力和表示能力就会发生改变。增加总的隐藏单元数量会增加模型的复杂度和参数量,可能导致过拟合。减少总的隐藏单元数量则可能导致模型的表达能力不足,无法捕捉到足够的特征和上下文信息。

因此,为了保持模型的表达能力和表示能力,多头注意力中通常保持总的隐藏单元数量不变,仅通过将其均摊给每个注意力头来实现并行计算。

Q:那么,所谓多头,实际上就是在原来的信息元素上对输入特征增加可训练权重参数的数量?比如说单头的只有Wq、Wk、Wv三个可训练参数,多头的可训练参数的数目就变成了Wqnum_heads、Wknum_heads、Wv*num_heads?

A:是的,您的理解是正确的。

在多头注意力中,每个注意力头都有自己的线性变换参数。以查询(Q)、键(K)、值(V)的线性变换为例,单头注意力只有一个查询矩阵Wq、一个键矩阵Wk和一个值矩阵Wv,而多头注意力有多个注意力头,每个注意力头都有自己的查询矩阵Wq、键矩阵Wk和值矩阵Wv。

因此,多头注意力的可训练参数的数量确实是单头的参数数量乘以注意力头的数量。具体地说,对于每个注意力头,线性变换的参数数量是原始线性变换参数的数量乘以注意力头的数量。例如,如果单头注意力有3个线性变换参数(Wq、Wk、Wv),而多头注意力有5个注意力头,那么多头注意力的线性变换参数的数量就是3个参数乘以5个注意力头,即15个参数。

通过多头注意力,模型能够利用不同的注意力头来学习不同的特征和上下文信息,从而增强了模型的表示能力和表达能力。每个注意力头都可以关注输入特征的不同方面,从而提供了更全面、更多样化的信息表示。

Q:那么多头注意力机制跟卷积神经网络的多层卷积核也是不一样的?卷积神经网络的多层卷积核中,隐藏层数目也是乘以了num_layers吧?

A:是的,多头注意力机制和卷积神经网络中的多层卷积核是不同的。

在多头注意力中,注意力头的作用类似于并行计算的不同线性变换。每个注意力头都有自己的查询、键和值的线性变换参数,并且并行地计算注意力权重和上下文向量。多头注意力通过多个注意力头的组合来捕捉不同特征和上下文信息。

而在卷积神经网络中,多层卷积核是指在同一层中使用多个卷积核进行并行计算。每个卷积核在输入数据上进行卷积操作,并提取特定的特征。多层卷积核通过并行计算和参数共享来捕捉不同层次和尺度的特征。

在多头注意力中,多头是在同一层的不同线性变换上进行并行计算。每个注意力头都有自己的线性变换参数,并计算注意力权重和上下文向量。注意力头之间没有参数共享,而是通过注意力权重的加权求和来得到最终的上下文向量。

在卷积神经网络中,多层卷积核是在同一层的不同卷积核上进行并行计算。每个卷积核通过卷积操作提取不同的特征,并输出不同的特征图。卷积核之间通常会共享参数,以减少参数量和提高模型的泛化能力。

因此,多头注意力和卷积神经网络中的多层卷积核是不同的机制,用于在不同的任务和场景中捕捉不同类型的特征和上下文信息。

附一张讨论区整理的关于多头注意力参数变化过程的图:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/346743
推荐阅读
相关标签
  

闽ICP备14008679号