
Implementing Self-Attention in PyTorch

The example below builds a simple self-attention layer on top of RNN outputs: the inputs are projected into Q, K and V, attention weights are computed as softmax(Q·K^T) with padding positions masked out according to the true sequence lengths, and the output is the attention-weighted sum of V.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention_Layer(nn.Module):

    # A self-attention layer with a padding mask
    def __init__(self, hidden_dim, is_bi_rnn):
        super(Attention_Layer, self).__init__()

        self.hidden_dim = hidden_dim
        self.is_bi_rnn = is_bi_rnn

        # Use nn.Linear layers (without bias) as the Q, K, V projection matrices
        if is_bi_rnn:
            # Bidirectional RNN: forward and backward states are concatenated, so the width doubles
            self.Q_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)
            self.K_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)
            self.V_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)
        else:
            # Unidirectional RNN
            self.Q_linear = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.K_linear = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.V_linear = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, inputs, lens):
        # inputs: [batch_size, max_len, hidden_dim] (hidden_dim * 2 if bidirectional)
        # lens:   true length of each sequence; max(lens) must equal inputs.size(1)
        size = inputs.size()

        # Project the inputs into the Q, K, V matrices
        Q = self.Q_linear(inputs)
        K = self.K_linear(inputs).permute(0, 2, 1)  # transpose K for the dot product
        V = self.V_linear(inputs)

        # Build the padding mask from the true sequence lengths
        max_len = max(lens)  # longest sentence in the batch
        sentence_lengths = torch.Tensor(lens)  # length of each sentence
        mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]
        mask = mask.unsqueeze(dim=1)  # [batch_size, 1, max_len]
        mask = mask.expand(size[0], max_len, max_len)  # [batch_size, max_len, max_len]
        
        # Matrix of large negative values used to fill the masked (padding) positions
        padding_num = torch.ones_like(mask)
        padding_num = -2 ** 31 * padding_num.float()

        # Raw attention scores
        alpha = torch.matmul(Q, K)

        # Mask out padding positions so softmax assigns them ~zero weight
        alpha = torch.where(mask, alpha, padding_num)
        # Normalize over the key dimension
        alpha = F.softmax(alpha, dim=2)

        # Weighted sum of the value vectors
        out = torch.matmul(alpha, V)

        return out
        

if __name__ == '__main__':

    out = torch.rand(3, 10, 128)  # pretend this is an RNN output: [batch_size, max_len, hidden_size * 2]
    att_L = Attention_Layer(64, True)  # arguments: hidden_size, is_bi_rnn (True for a bidirectional RNN)
    lens = [7, 10, 4]  # true lengths of the texts in the batch

    att_out = att_L(out, lens)  # run the attention layer
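
A quick sanity check, added here as a minimal sketch (the print call below is not in the original code and assumes it runs inside the same __main__ block): self-attention preserves the shape of its input, so att_out should have the same size as out, namely [3, 10, 128]. Note also that this layer uses unscaled dot-product scores; the standard Transformer formulation would additionally divide the scores by sqrt(d_k) before the softmax.

    # Not in the original code: verify that the output shape matches the input shape
    print(att_out.size())  # expected: torch.Size([3, 10, 128])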
        
        