import numpy as np
from numpy.random import randn
d = 256 #dimension
n = 32 # 32 tokens in the sequence
x = randn(d,n)
x.shape
(256, 32)
w_q = randn(d,d)
w_k = randn(d,d)
w_v = randn(d,d)
q = w_q @ x
k = w_k @ x
v = w_v @ x
q.shape
(256, 32)
A = k.T @ q
A.shape,v.shape
((32, 32), (256, 32))
Compute the dot product between Q and K; to keep the result from growing too large, it is divided by a scaling factor $\sqrt{d_k}$, where $d_k$ is the dimension of a single query/key vector.
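Written out for the column-vector layout used here (each column of x is one token), the whole operation is

$\text{output} = V\,\mathrm{softmax}\!\left(\frac{K^{\top} Q}{\sqrt{d_k}}\right)$

with the softmax taken column-wise (this restatement is mine, not from the original post); the code below applies exactly this scaling, with $d_k = d = 256$.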
A /= np.sqrt(d)
A
array([[ 467.16794715, 380.39806016, -360.37257332, ..., -615.72400039,
-212.19910996, -37.3895145 ],
[ 310.8283912 , -39.95152262, 77.53697612, ..., 58.05146241,
178.94456822, 201.25240535],
[-530.60357521, -179.45154641, -141.37155644, ..., 449.33825889,
627.32801325, 271.24241891],
...,
[ -71.70217207, -123.44349477, 158.86743499, ..., 285.45057475,
115.29456888, -10.19809743],
[ 78.8631081 , 411.99848162, 96.31579751, ..., 7.30553679,
457.97307249, 287.97459693],
[ 535.17497128, 258.96003762, -77.00325112, ..., -234.59521067,
240.60742865, 211.40853578]])
def softmax(x):
    e_x = np.exp(x - np.max(x))  # subtract the max for numerical stability (overflow guard)
    return e_x / e_x.sum(axis=0)
A_hat = softmax(A)
A_hat
array([[2.91692116e-030, 1.34516507e-086, 0.00000000e+000, ...,
0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
[3.69418159e-098, 0.00000000e+000, 0.00000000e+000, ...,
0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
[0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
1.83462355e-049, 1.00000000e+000, 3.65481942e-133],
...,
[0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
1.22470105e-120, 0.00000000e+000, 0.00000000e+000],
[0.00000000e+000, 7.12302144e-073, 0.00000000e+000, ...,
0.00000000e+000, 2.81892644e-074, 6.75396039e-126],
[1.00000000e+000, 2.44856986e-139, 0.00000000e+000, ...,
0.00000000e+000, 1.12042830e-168, 3.77862069e-159]])
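Because the softmax was taken over axis 0, each column of A_hat is a probability distribution over the 32 positions (in this example the unscaled scores are so large that most of the mass collapses onto a single position). As a quick sanity check, not part of the original output, the column sums should all be 1:

A_hat.sum(axis=0)  # 32 values, each numerically equal to 1.0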
output = v @ A_hat
output
array([[-16.27732457, 5.01921581, 7.96516927, ..., -10.3259554 ,
10.76273056, -4.27466928],
[ 3.17159406, 5.30274362, -23.31649146, ..., 13.74964699,
8.92598827, -10.29046234],
[ -9.85427512, -6.53049358, -6.49381562, ..., -2.14490419,
1.16150094, -1.03177095],
...,
[ 7.30371807, 8.1844567 , 11.42067085, ..., 2.66942536,
3.87896518, -11.65066698],
[ 19.24926147, 2.48411984, -1.61712345, ..., 11.1749362 ,
0.41691663, 8.12821816],
[ 7.4601276 , 2.60847536, -8.47258352, ..., 33.62259747,
7.92981574, -5.35334156]])
output.shape
(256, 32)
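For convenience, the steps above can be collected into a single NumPy helper. This is a minimal sketch of the same computation; the function name self_attention_numpy is my own, not from the original post:

def self_attention_numpy(x, w_q, w_k, w_v):
    # x: (d, n), one token per column; w_q / w_k / w_v: (d, d) projection matrices
    q = w_q @ x
    k = w_k @ x
    v = w_v @ x
    a = (k.T @ q) / np.sqrt(q.shape[0])  # (n, n) scaled scores; d_k = d here
    a_hat = softmax(a)                   # column-wise softmax defined above
    return v @ a_hat                     # (d, n)

self_attention_numpy(x, w_q, w_k, w_v).shape
(256, 32)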
from math import sqrt
import torch
import torch.nn as nn
# Implementation of the Self-Attention mechanism
class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : projects input_dim -> dim_k
    # k : projects input_dim -> dim_k
    # v : projects input_dim -> dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        Q = self.q(x)  # Q: batch_size * seq_len * dim_k
        K = self.k(x)  # K: batch_size * seq_len * dim_k
        V = self.v(x)  # V: batch_size * seq_len * dim_v
        # scale the scores by 1/sqrt(dim_k) before the softmax
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)) * self._norm_fact)  # Q @ K.T : batch_size * seq_len * seq_len
        output = torch.bmm(atten, V)  # softmax(Q @ K.T / sqrt(dim_k)) @ V : batch_size * seq_len * dim_v
        return output
X = torch.randn(4,3,2)
X
tensor([[[-0.8764, 1.5286],
[-1.4622, 0.0379],
[-0.4678, -0.5522]],
[[-1.1969, -0.0225],
[-0.9853, -1.3157],
[-0.8873, 0.0357]],
[[ 1.5263, -1.0748],
[ 2.3504, -0.0865],
[-0.4937, -1.3872]],
[[ 0.3569, -1.6826],
[ 0.5223, 1.3726],
[ 0.7659, -1.2728]]])
X.shape
torch.Size([4, 3, 2])
self_attention = Self_Attention(2,4,5)
res = self_attention(X)
res
tensor([[[ 0.2684,  0.2199, -0.4360, -0.1844, -0.0310],
         [ 0.2813,  0.2159, -0.4956, -0.1817, -0.0702],
         [ 0.2819,  0.2511, -0.4872, -0.1941, -0.0687]],

        [[ 0.2818,  0.5077, -0.4062, -0.2856, -0.0447],
         [ 0.2828,  0.5048, -0.4117, -0.2844, -0.0481],
         [ 0.2816,  0.5136, -0.4034, -0.2877, -0.0435]],

        [[ 0.1872,  0.0663, -0.1166, -0.1376,  0.1993],
         [ 0.1331, -0.4651, -0.0385,  0.0464,  0.3131],
         [ 0.2440,  0.5729, -0.2148, -0.3125,  0.0749]],

        [[ 0.2048,  0.0763, -0.1932, -0.1394,  0.1472],
         [ 0.2002,  0.1248, -0.1574, -0.1571,  0.1653],
         [ 0.2040,  0.1136, -0.1779, -0.1528,  0.1530]]],
       grad_fn=<BmmBackward0>)
res.shape
torch.Size([4, 3, 5])
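For reference, PyTorch 2.0 and later also expose torch.nn.functional.scaled_dot_product_attention, which evaluates the same softmax(QKᵀ/√d_k)V expression with the scaling applied internally. A minimal sketch, assuming PyTorch ≥ 2.0 and reusing the projections from the module above:

import torch.nn.functional as F

Q = self_attention.q(X)  # batch_size * seq_len * dim_k
K = self_attention.k(X)  # batch_size * seq_len * dim_k
V = self_attention.v(X)  # batch_size * seq_len * dim_v
out = F.scaled_dot_product_attention(Q, K, V)  # scaling by 1/sqrt(dim_k) handled internally
out.shape
torch.Size([4, 3, 5])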