
【Self-Attention from Scratch】NumPy and PyTorch implementations of Self-Attention



Theory


For the theoretical background, see my earlier blog post.
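For quick reference, the standard scaled dot-product attention formula implemented below is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V$$

The PyTorch version follows this row-vector form directly. The NumPy version stores tokens as columns (Q, K, V each have shape d × n), so the same computation appears transposed there: output = V · softmax(Kᵀ Q / √d_k), with the softmax taken over each column.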


Self-Attention: NumPy implementation


import numpy as np
from numpy.random import randn


d = 256  # embedding dimension of each token
n = 32   # a sequence of 32 tokens
x = randn(d, n)  # each column is one token vector
x.shape
(256, 32)


w_q = randn(d, d)  # query projection matrix
w_k = randn(d, d)  # key projection matrix
w_v = randn(d, d)  # value projection matrix

q = w_q @ x  # (d, n)
k = w_k @ x  # (d, n)
v = w_v @ x  # (d, n)

q.shape
(256, 32)


A = k.T @ q  # (n, n) raw attention scores: entry (i, j) is the dot product of key i with query j
A.shape,v.shape
((32, 32), (256, 32))


Compute the dot products between Q and K; then, to keep the results from growing too large, divide them by the scaling factor $\sqrt{d_k}$, where $d_k$ is the dimension of a single query/key vector.
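With the dimensions used here, $d_k = d = 256$, so every score in A is divided by $\sqrt{256} = 16$.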

A /= np.sqrt(d)  # scale the scores by sqrt(d_k)
A
array([[ 467.16794715,  380.39806016, -360.37257332, ..., -615.72400039,
        -212.19910996,  -37.3895145 ],
       [ 310.8283912 ,  -39.95152262,   77.53697612, ...,   58.05146241,
         178.94456822,  201.25240535],
       [-530.60357521, -179.45154641, -141.37155644, ...,  449.33825889,
         627.32801325,  271.24241891],
       ...,
       [ -71.70217207, -123.44349477,  158.86743499, ...,  285.45057475,
         115.29456888,  -10.19809743],
       [  78.8631081 ,  411.99848162,   96.31579751, ...,    7.30553679,
         457.97307249,  287.97459693],
       [ 535.17497128,  258.96003762,  -77.00325112, ..., -234.59521067,
         240.60742865,  211.40853578]])
def softmax(x):
    e_x = np.exp(x - np.max(x))   # subtract the max to prevent overflow
    return e_x / e_x.sum(axis=0)  # normalize so each column sums to 1
A_hat = softmax(A)
A_hat
array([[2.91692116e-030, 1.34516507e-086, 0.00000000e+000, ...,
        0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
       [3.69418159e-098, 0.00000000e+000, 0.00000000e+000, ...,
        0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
       [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
        1.83462355e-049, 1.00000000e+000, 3.65481942e-133],
       ...,
       [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
        1.22470105e-120, 0.00000000e+000, 0.00000000e+000],
       [0.00000000e+000, 7.12302144e-073, 0.00000000e+000, ...,
        0.00000000e+000, 2.81892644e-074, 6.75396039e-126],
       [1.00000000e+000, 2.44856986e-139, 0.00000000e+000, ...,
        0.00000000e+000, 1.12042830e-168, 3.77862069e-159]])


output = v @ A_hat  # weighted sum of value vectors, shape (d, n)
output
array([[-16.27732457,   5.01921581,   7.96516927, ..., -10.3259554 ,
         10.76273056,  -4.27466928],
       [  3.17159406,   5.30274362, -23.31649146, ...,  13.74964699,
          8.92598827, -10.29046234],
       [ -9.85427512,  -6.53049358,  -6.49381562, ...,  -2.14490419,
          1.16150094,  -1.03177095],
       ...,
       [  7.30371807,   8.1844567 ,  11.42067085, ...,   2.66942536,
          3.87896518, -11.65066698],
       [ 19.24926147,   2.48411984,  -1.61712345, ...,  11.1749362 ,
          0.41691663,   8.12821816],
       [  7.4601276 ,   2.60847536,  -8.47258352, ...,  33.62259747,
          7.92981574,  -5.35334156]])
output.shape
(256, 32)
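To recap, the NumPy steps above can be collected into a single function. This is a minimal sketch that reuses the numpy import and the softmax defined above, keeps the same column-vector layout, and uses an illustrative function name:

def self_attention_np(x, w_q, w_k, w_v):
    # x: (d, n) -- each column is one token vector
    q = w_q @ x                          # (d, n)
    k = w_k @ x                          # (d, n)
    v = w_v @ x                          # (d, n)
    A = (k.T @ q) / np.sqrt(q.shape[0])  # (n, n) scaled attention scores
    A_hat = softmax(A)                   # column-wise softmax (defined above)
    return v @ A_hat                     # (d, n)

output = self_attention_np(x, w_q, w_k, w_v)
output.shape  # (256, 32)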

Self-Attention: PyTorch implementation


from math import sqrt
import torch
import torch.nn as nn
# Implementation of the Self-Attention mechanism
class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # Q     : batch_size * seq_len * dim_k
    # K     : batch_size * seq_len * dim_k
    # V     : batch_size * seq_len * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self._norm_fact = 1 / sqrt(dim_k)  # scaling factor 1 / sqrt(d_k)

    def forward(self, x):
        Q = self.q(x)  # Q: batch_size * seq_len * dim_k
        K = self.k(x)  # K: batch_size * seq_len * dim_k
        V = self.v(x)  # V: batch_size * seq_len * dim_v

        # scale Q @ K^T by 1/sqrt(d_k) *before* the softmax
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)) * self._norm_fact)  # batch_size * seq_len * seq_len

        output = torch.bmm(atten, V)  # softmax(Q @ K^T / sqrt(d_k)) @ V  # batch_size * seq_len * dim_v

        return output
K.permute(0,2,1)  # swaps dimension index 1 and dimension index 2 of K
torch.bmm         # batched matrix multiplication of two 3-D tensors
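A quick shape check of these two operations (illustrative only; the tensors here are freshly made up and unrelated to the module above, reusing the torch import from earlier):

Q = torch.randn(4, 3, 2)    # batch_size=4, seq_len=3, dim_k=2
K = torch.randn(4, 3, 2)
K_t = K.permute(0, 2, 1)    # swap dims 1 and 2 -> shape (4, 2, 3)
scores = torch.bmm(Q, K_t)  # (4, 3, 2) @ (4, 2, 3) -> (4, 3, 3)
Q.shape, K_t.shape, scores.shape
# (torch.Size([4, 3, 2]), torch.Size([4, 2, 3]), torch.Size([4, 3, 3]))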
X = torch.randn(4,3,2)
X
tensor([[[-0.8764,  1.5286],
         [-1.4622,  0.0379],
         [-0.4678, -0.5522]],

        [[-1.1969, -0.0225],
         [-0.9853, -1.3157],
         [-0.8873,  0.0357]],

        [[ 1.5263, -1.0748],
         [ 2.3504, -0.0865],
         [-0.4937, -1.3872]],

        [[ 0.3569, -1.6826],
         [ 0.5223,  1.3726],
         [ 0.7659, -1.2728]]])
X.shape
torch.Size([4, 3, 2])
self_attention = Self_Attention(2, 4, 5)  # input_dim=2, dim_k=4, dim_v=5
res = self_attention(X)
res
tensor([[[ 0.2684,  0.2199, -0.4360, -0.1844, -0.0310],
         [ 0.2813,  0.2159, -0.4956, -0.1817, -0.0702],
         [ 0.2819,  0.2511, -0.4872, -0.1941, -0.0687]],

        [[ 0.2818,  0.5077, -0.4062, -0.2856, -0.0447],
         [ 0.2828,  0.5048, -0.4117, -0.2844, -0.0481],
         [ 0.2816,  0.5136, -0.4034, -0.2877, -0.0435]],

        [[ 0.1872,  0.0663, -0.1166, -0.1376,  0.1993],
         [ 0.1331, -0.4651, -0.0385,  0.0464,  0.3131],
         [ 0.2440,  0.5729, -0.2148, -0.3125,  0.0749]],

        [[ 0.2048,  0.0763, -0.1932, -0.1394,  0.1472],
         [ 0.2002,  0.1248, -0.1574, -0.1571,  0.1653],
         [ 0.2040,  0.1136, -0.1779, -0.1528,  0.1530]]],
       grad_fn=<BmmBackward0>)
res.shape
torch.Size([4, 3, 5])