
Attention Mechanisms: Multi-Head Attention, Self-Attention, and Combining Self-Attention with an LSTM

Attention mechanism assignment

While working on this, I found three bugs in the code. The first is that the self-attention module did not add positional encoding.

The second is that positional encoding was also missing in the model combining the LSTM with attention.

The third is an error in the custom pooling layer; it is marked in the code below.

Multi-Head Attention

From what we learned earlier, you may have noticed a problem with the attention scoring function above: it has no learnable parameters. We therefore introduce the idea of "multiple heads": the input vectors are passed through different linear projections to produce the queries, keys, and values, where each linear projection is a learnable matrix.
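
In the standard formulation, which the implementation below follows, each head applies its own learned projections and the head outputs are concatenated and passed through one more linear map:

$\mathrm{head}_i = \mathrm{Attention}(Q W_i^{(q)}, K W_i^{(k)}, V W_i^{(v)})$ , $\mathrm{MultiHead}(Q, K, V) = W_o \, [\mathrm{head}_1; \dots; \mathrm{head}_h]$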

To allow all heads to be computed in parallel, we first define the two reshaping helpers below. Specifically, the transpose_output function reverses the operation of transpose_qkv.

# Import the necessary libraries
import os
import math
import torch 
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
from torch.utils import data
import re
import collections
import random
# Modules implemented last week
def masked_softmax(X, valid_lens):
    """Perform softmax by masking elements on the last axis"""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements on the last axis are replaced with a very large
        # negative value, so their softmax output is 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Use dropout for regularization
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        # Inside multi-head attention, queries, keys and values have shape
        # (batch_size*num_heads, no. of q/k/v, num_hiddens/num_heads)

        # Size of the last dimension
        d = queries.shape[-1]

        # Scaled dot product
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # Shape of scores: (batch_size*num_heads, no. of queries, no. of keys)

        # Attention weights
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of self.attention_weights: (batch_size*num_heads, no. of queries, no. of key-value pairs)

        # Weighted sum of the values
        return torch.bmm(self.dropout(self.attention_weights), values), scores
        # Shape: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
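
For reference, DotProductAttention above computes the standard scaled dot-product attention, masking the invalid positions before the softmax:

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V$
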
def transpose_qkv(X, num_heads):
    # Shape of input X: (batch_size, no. of queries or key-value pairs, num_hiddens)

    # Shape after reshape: (batch_size, no. of queries or key-value pairs, num_heads, num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape after permute: (batch_size, num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # Final output shape: (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    # Input shape: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])

    # (batch_size, no. of queries or key-value pairs, num_heads, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # (batch_size, no. of queries or key-value pairs, num_heads*num_hiddens/num_heads)
    return X.reshape(X.shape[0], X.shape[1], -1)
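
As a quick sanity check (not part of the original text; the shapes here are chosen arbitrarily), transpose_output should exactly undo transpose_qkv:

# Hypothetical shapes: batch_size=2, 4 queries, num_hiddens=100, num_heads=5
X = torch.arange(2 * 4 * 100, dtype=torch.float32).reshape(2, 4, 100)
Y = transpose_qkv(X, 5)               # (batch_size*num_heads, 4, num_hiddens/num_heads)
print(Y.shape)                        # torch.Size([10, 4, 20])
Z = transpose_output(Y, 5)            # back to the original layout
print(Z.shape, torch.equal(X, Z))     # torch.Size([2, 4, 100]) True
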
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, values: (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        # After the transformation, queries, keys and values have shape
        # (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) num_heads times,
            # then the second item, and so on
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        # Shape of output: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
        output, score = self.attention(queries, keys, values, valid_lens)

        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        
        # The returned scores are reshaped to (num_heads, batch_size, no. of queries, no. of key-value pairs)
        return self.W_o(output_concat), score.reshape(-1, self.num_heads, score.shape[1], score.shape[2]).permute(1, 0, 2, 3)

The multi-head attention module outputs a tensor of shape (batch_size, num_queries, num_hiddens).

num_hiddens, num_heads = 100, 5
# Using num_hiddens everywhere is just for convenience; strictly these are the q/k/v feature sizes and the hidden size
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

Exercise 2: After the modification, the scores are also returned, with shape (num_heads, batch_size, num_queries, num_kvpairs).

Each head's scores could be printed with a for loop, but that takes a lot of space, so only one head is printed here. Each head's scores have shape (batch_size, num_queries, num_kvpairs), which matches the expectation.

batch_size, num_queries = 2, 3
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
# Shape of X: (batch_size, no. of queries, query feature size)
X = torch.ones((batch_size, num_queries, num_hiddens))
# Shape of Y: (batch_size, no. of key-value pairs, key-value feature size)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# Here the keys and values are set to be the same
att, score = attention(X, Y, Y, valid_lens)
print(att.shape, score.shape)

# Print the scores of each head
# for i in range(score.shape[0]):
#     print(f'heads{i}:{score[i]}')
#     print(f'heads{i}.shape:{score[i].shape}')
print(f'heads{0}:{score[0]}')
print(f'heads{0}.shape:{score[0].shape}')
torch.Size([2, 3, 100]) torch.Size([5, 2, 3, 6])
heads0:tensor([[[ 1.1521e-01,  1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06]],

        [[ 1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06]]], grad_fn=<SelectBackward0>)
heads0.shape:torch.Size([2, 3, 6])

Exercise 3: Randomly initialize a matrix and simulate the forward pass of self-attention.

The resulting shapes match the expectation.

num_hiddens, num_heads = 500, 5
batch_size, num_queries = 3, 6
num_kvpairs, valid_lens = 8, torch.tensor([3, 2, 3])

# Shape of X: (batch_size, no. of queries, query feature size)
X = torch.rand((batch_size, num_queries, num_hiddens))
# Shape of Y: (batch_size, no. of key-value pairs, key-value feature size)
Y = torch.rand((batch_size, num_kvpairs, num_hiddens))
attention1 = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
# Here the keys and values are set to be the same
att, score = attention1(X, Y, Y, valid_lens)
print(f'att shape: {att.shape}, multi-head score shape: {score.shape}')
print(f'heads{0}:{score[0]}')
print(f'heads{0}.shape:{score[0].shape}')
att shape: torch.Size([3, 6, 500]), multi-head score shape: torch.Size([5, 3, 6, 8])
heads0:tensor([[[-2.2675e-02, -2.0155e-02, -1.6584e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 3.4621e-02,  2.7170e-02, -1.0288e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 4.4613e-02,  1.7425e-02, -4.0215e-02, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-2.1214e-02, -7.5245e-02, -1.2768e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-6.9279e-02, -1.0662e-01, -2.2013e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-9.0299e-02, -6.2465e-02, -2.1125e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06]],

        [[-2.0877e-01, -3.1303e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.1978e-01, -2.2805e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.6171e-01, -2.3794e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.6123e-01, -3.0902e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-5.1692e-02, -1.3835e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-3.8984e-02, -2.4244e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06]],

        [[-1.0766e-01, -1.5127e-01, -1.3136e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.9279e-01, -2.0913e-01, -1.3926e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-5.1507e-02, -6.8220e-02, -1.5038e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-6.9059e-02, -1.1010e-01, -1.2919e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-5.1092e-02, -7.4127e-02, -8.4521e-02, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-9.9756e-02, -1.8430e-01, -2.1404e-01, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06]]], grad_fn=<SelectBackward0>)
heads0.shape:torch.Size([3, 6, 8])

Self-Attention in Practice

With the attention mechanism in place, we feed a sequence of tokens into attention pooling so that the same set of tokens acts simultaneously as queries, keys, and values. Specifically, each query attends to all key-value pairs and produces one attention output. Because the queries, keys, and values come from the same input, this is called self-attention.
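
Before adding position information, here is a minimal sketch of self-attention that reuses the MultiHeadAttention defined above: the same tensor X is passed in as queries, keys and values (the shapes below are purely illustrative).

num_hiddens, num_heads = 100, 5
self_attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                    num_hiddens, num_heads, 0.5)
self_attention.eval()
X = torch.ones((2, 4, num_hiddens))            # (batch_size, no. of tokens, num_hiddens)
out, _ = self_attention(X, X, X, torch.tensor([3, 2]))
print(out.shape)                               # expected: torch.Size([2, 4, 100])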

When processing a sequence of tokens, a recurrent neural network handles the tokens one by one, whereas self-attention gives up this sequential processing in favor of parallel computation. To make use of the order of the sequence, we inject absolute or relative position information by adding a positional encoding to the input representations. The positional encoding can either be learned or fixed.

Below we describe the fixed positional encoding based on sine and cosine functions. Suppose the input X has shape (batch size, number of tokens, feature dimension); the positional encoding uses a positional embedding matrix P of the same shape, and the output is X + P.

The entries of the positional embedding matrix P are computed as:
$p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right)$ , $p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right)$

class PositionalEncoding(nn.Module):
    """Positional encoding"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a sufficiently long matrix P
        self.P = torch.zeros((1, max_len, num_hiddens))

        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
batch_size, encoding_dim, num_steps = 2, 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((batch_size, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]

print('Output features:', X)
print('Feature shape:', X.shape)
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
Output features: tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,
           1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,
           3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,
           1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,
           1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,
           1.0492e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,
           1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,
           3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,
           1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,
           1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,
           1.0492e-02,  9.9994e-01]]])
Feature shape: torch.Size([2, 60, 32])
# Modules implemented last week
def masked_softmax(X, valid_lens):
    """Perform softmax by masking elements on the last axis"""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements on the last axis are replaced with a very large
        # negative value, so their softmax output is 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


def split_head_reshape(X, heads_num, head_size):
    # Split the hidden dimension into heads and fold the heads into the batch dimension:
    # (batch_size, seq_len, hidden_size) -> (batch_size*heads_num, seq_len, head_size)
    batch_size, seq_len, hidden_size = X.shape
    X = X.reshape(batch_size, seq_len, heads_num, head_size)
    X = X.permute(0, 2, 1, 3)
    X = X.reshape(batch_size * heads_num, seq_len, head_size)
    return X


class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Use dropout for regularization
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, inputs_size, heads_num, dropout=0.0):
        super(MultiHeadSelfAttention, self).__init__()
        self.inputs_size = inputs_size
        self.qsize, self.ksize, self.vsize = inputs_size, inputs_size, inputs_size

        self.heads_num = heads_num
        self.head_size = inputs_size // heads_num
        self.Q_proj = nn.Linear(self.qsize, inputs_size)
        self.K_proj = nn.Linear(self.ksize, inputs_size)
        self.V_proj = nn.Linear(self.vsize, inputs_size)
        self.out_proj = nn.Linear(inputs_size, inputs_size)
        self.attention = DotProductAttention(0.5)

        self.positional_encoding = PositionalEncoding(inputs_size, 0)

    def forward(self, X, valid_lens):

        # Bug 1 fix: add positional encoding
        # (PositionalEncoding already returns X + P, so X is not added a second time here)
        X = self.positional_encoding(X)

        self.batch_size, self.seq_len, self.hidden_size = X.shape
        Q, K, V = self.Q_proj(X), self.K_proj(X), self.V_proj(X)
        # Shape of Q, K, V after the split: (batch_size*heads_num, seq_len, head_size)
        Q, K, V = [split_head_reshape(item, self.heads_num, self.head_size) for item in [Q, K, V]]

        if valid_lens is not None:
            # Repeat valid_lens for each head, since the heads are folded into the batch dimension
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.heads_num, dim=0)

        out = self.attention(Q, K, V, valid_lens)
        # Undo the head split: back to (batch_size, seq_len, heads_num*head_size)
        out = out.reshape(self.batch_size, self.heads_num, self.seq_len, self.head_size)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(self.batch_size, self.seq_len, out.shape[2] * out.shape[3])
        out = self.out_proj(out)
        return out
x= torch.rand((2, 2, 4))
valid_lens = torch.tensor([3]).repeat(x.shape[0])
head = MultiHeadSelfAttention(4, 1)

context = head(x, valid_lens)
print(context.shape)

torch.Size([2, 2, 4])

Defining a model that combines an LSTM with multi-head attention

We want to solve a binary text classification task: each example is a piece of text with a label of 0 or 1.
First we define a pooling module. In general, the model output has shape (batch size, number of time steps, feature dimension). The pooling module averages the features over the valid positions to obtain a single representation of the whole text, which is then fed into a classifier for the final prediction. After the pooling module, the output therefore has one fewer dimension.

class AveragePooling(nn.Module):
    def __init__(self):
        super(AveragePooling, self).__init__()

    def forward(self, inputs, valid_length):

        valid_length = valid_length.unsqueeze(-1)
        print("valid_length.shape", valid_length.shape)

        # Bug 3 fix: max_len must come from the sequence dimension of inputs,
        # not from valid_length (whose second dimension is 1 after unsqueeze)
        max_len = inputs.shape[1]
        # max_len = valid_length.shape[1]  # original (buggy) line

        mask = torch.arange(max_len) < valid_length
        print("mask.shape:", mask.shape, "mask:", mask)
        mask = mask.unsqueeze(-1)
        print("mask.unsq:", mask.shape)
        # Zero out the padded positions, then average over the valid ones
        inputs = torch.multiply(inputs, mask)
        print("inputs.shape:", inputs.shape)
        mean_outputs = torch.divide(torch.sum(inputs, dim=1), valid_length)
        print("meao", mean_outputs.shape)
        return mean_outputs
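
As a small check (the numbers here are hypothetical, not from the original), the pooling module should average only over the first valid_length time steps of each example:

pool = AveragePooling()
inputs = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # (batch, steps, features)
valid = torch.tensor([2, 3])
out = pool(inputs, valid)   # mean over the first 2 steps for example 0, all 3 for example 1
print(out.shape)            # expected: torch.Size([2, 4])
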

Next, we implement a network that combines an LSTM with multi-head attention.

class lstmWithAttention(nn.Module):
    def __init__(self, hidden_size, embedding_size, vocab_size, n_classes, attention=None):
        super(lstmWithAttention, self).__init__()

        self.hidden_size = hidden_size
        # Dimension of the word vectors
        self.embedding_size = embedding_size
        # Size of the vocabulary, i.e. the number of words it contains
        self.vocab_size = vocab_size
        self.n_classes = n_classes

        # The modules used in the forward pass
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)

        # Bug 2 fix: the LSTM must be created with batch_first=True
        self.lstm = nn.LSTM(embedding_size, hidden_size, bidirectional=True, batch_first=True)
        # self.lstm = nn.LSTM(embedding_size, hidden_size, bidirectional=True)  # original (buggy) line

        self.attention = attention
        self.pooling = AveragePooling()
        self.cls_fn = nn.Linear(self.hidden_size * 2, self.n_classes)

    def forward(self, inputs):
        valid_lens = torch.tensor([10]).repeat(inputs.shape[0])  # one entry per example in the batch
        print("valid_lens", valid_lens.shape)

        # Map word indices to word vectors
        embedded_input = self.embedding(inputs)
        # Shape of embedded_input: (batch_size, sequence length, embedding_size)
        print("embedded_input", embedded_input.shape)

        # Shape of last_layers_hiddens: (batch_size, sequence length, hidden_size*2)
        last_layers_hiddens, state = self.lstm(embedded_input)
        print("last_layers_hiddens", last_layers_hiddens.shape)

        # Multi-head self-attention: (batch_size, sequence length, hidden_size*2)
        last_layers_hiddens = self.attention(last_layers_hiddens, valid_lens)

        # Bug 3: the pooling layer had an error (fixed above)
        last_layers_hiddens = self.pooling(last_layers_hiddens, valid_lens)

        print('pooling:', last_layers_hiddens.shape)
        logits = self.cls_fn(last_layers_hiddens)

        return logits
hidden_size = 32
embedding_size = 32
n_classes = 2
n_layers = 1
# vocab_size is the size of the vocabulary, i.e. the number of words; here we assume 50000
vocab_size = 50000


head = MultiHeadSelfAttention(inputs_size=64, heads_num=2)
# Define the model
model = lstmWithAttention(hidden_size, embedding_size, vocab_size, n_classes, attention=head)

print(model)

lstmWithAttention(
  (embedding): Embedding(50000, 32)
  (lstm): LSTM(32, 32, batch_first=True, bidirectional=True)
  (attention): MultiHeadSelfAttention(
    (Q_proj): Linear(in_features=64, out_features=64, bias=True)
    (K_proj): Linear(in_features=64, out_features=64, bias=True)
    (V_proj): Linear(in_features=64, out_features=64, bias=True)
    (out_proj): Linear(in_features=64, out_features=64, bias=True)
    (attention): DotProductAttention(
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (positional_encoding): PositionalEncoding(
      (dropout): Dropout(p=0, inplace=False)
    )
  )
  (pooling): AveragePooling()
  (cls_fn): Linear(in_features=64, out_features=2, bias=True)
)
x = torch.randint(low=0, high=50000, size=(50, 20))
x.shape
torch.Size([50, 20])
model(x).shape
valid_lens torch.Size([50])
embedded_input torch.Size([50, 20, 32])
last_layers_hiddens torch.Size([50, 20, 64])
valid_length.shape torch.Size([50, 1])
mask.shape: torch.Size([50, 20]) mask: tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False]])
mask.unsq: torch.Size([50, 20, 1])
inputs.shape: torch.Size([50, 20, 64])
meao torch.Size([50, 64])
pooling: torch.Size([50, 64])





torch.Size([50, 2])