
The Multi-Head Attention Structure in the Transformer

1. Multi-Head Attention Structure Diagram

(Figure: Multi-Head Attention structure diagram)

  • $Q, K, V$ can be split into $h$ equal parts $\{Q_1,K_1,V_1\},\cdots,\{Q_h,K_h,V_h\}$ according to the number of heads $h$, and each part is fed into its own self-attention block;
  • the outputs of the $h$ self-attention blocks are concatenated and passed through a linear transformation to produce the final output; a compact NumPy sketch of this whole flow is given right after this list.
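The flow described above can be condensed into a short NumPy sketch. This is an illustrative sketch rather than code from the original article: the function name multi_head_attention and the use of np.split are my own choices, and it assumes the model dimension is divisible by the number of heads h.

import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))  # subtract the row max for numerical stability
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V          # linear projections, see formulas (1)-(3) in Section 2
    d_k = Q.shape[1] // h                        # per-head dimension
    heads = []
    for Q_i, K_i, V_i in zip(np.split(Q, h, axis=1),
                             np.split(K, h, axis=1),
                             np.split(V, h, axis=1)):
        scores = Q_i @ K_i.T / np.sqrt(d_k)      # scaled dot-product scores
        heads.append(softmax(scores) @ V_i)      # self-attention output of one head
    return np.concatenate(heads, axis=1) @ W_O   # concatenate, then final linear transformation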

2. The Computation Flow of Multi-Head Attention

2.1 Obtaining $Q$, $K$, $V$

$Q$, $K$, $V$ are obtained by applying linear transformations to the input word-embedding matrix $\hat{X}$:

$$\hat{X}W^Q = Q \tag{1}$$
$$\hat{X}W^K = K \tag{2}$$
$$\hat{X}W^V = V \tag{3}$$

where $W^Q$, $W^K$, $W^V$ are the linear transformation matrices.

For ease of understanding, let the input embedding matrix $\hat{X}$ have size $4\times 6$ and the linear transformation matrices $W^Q, W^K, W^V$ have size $6\times 6$, as shown in the figure below.

(Figure: the $4\times 6$ matrix $\hat{X}$ and the $6\times 6$ matrices $W^Q, W^K, W^V$)

According to formulas (1), (2) and (3), the resulting $Q$, $K$, $V$ each have size $4\times 6$, as shown below.

(Figure: $Q$, $K$, $V$, each of size $4\times 6$)

Assuming the number of heads is 3, we can split $Q$, $K$, $V$ into three equal parts, as shown below; a small shape-check sketch follows the figure.

(Figure: $Q$, $K$, $V$ each split column-wise into three $4\times 2$ heads)
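As a quick shape check, the projection and the equal column split can be sketched as follows. This is illustrative only: the variable names and the use of np.split are mine, and the random values are placeholders rather than the concrete numbers used in Section 2.3.

import numpy as np

X_hat = np.random.rand(4, 6)            # stand-in for the 4 x 6 embedding matrix
W_Q = np.random.rand(6, 6)              # stand-in for the 6 x 6 projection matrix
Q = X_hat @ W_Q                         # formula (1): Q has shape (4, 6)
Q_1, Q_2, Q_3 = np.split(Q, 3, axis=1)  # 3 heads, each taking 2 of the 6 columns
print(Q.shape, Q_1.shape, Q_2.shape, Q_3.shape)  # (4, 6) (4, 2) (4, 2) (4, 2)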

2.2 The self-attention output of each $\{Q_1,K_1,V_1\},\cdots,\{Q_h,K_h,V_h\}$

The output of each head is computed as

$$\mathrm{Attention}(Q_i,K_i,V_i)=\mathrm{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i \tag{4}$$

In formula (4), $d_k$ is the number of columns of each $\{Q_i,K_i,V_i\}$ in the figure above, i.e. $d_k=2$, as shown below.

(Figure: each head $\{Q_i,K_i,V_i\}$ has $d_k = 2$ columns)

From formula (4), the output of each self-attention head has size $4\times 2$, as shown below.

(Figure: each self-attention output is of size $4\times 2$)

The concatenation step joins the three self-attention outputs side by side, as shown below; a per-head sketch of formula (4) follows the figure.

(Figure: the three $4\times 2$ outputs concatenated into a $4\times 6$ matrix)
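A per-head implementation of formula (4) could look like the sketch below. This is illustrative only: the helper name scaled_dot_product_attention is mine, and the row-max subtraction is the numerical-stability trick also used in the code of Section 3.

import numpy as np

def scaled_dot_product_attention(Q_i, K_i, V_i):
    d_k = K_i.shape[1]                               # number of columns of each head, here 2
    scores = Q_i @ K_i.T / np.sqrt(d_k)              # formula (4): scaled dot product
    scores -= np.max(scores, axis=1, keepdims=True)  # subtract the row max for stability
    weights = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)  # row-wise softmax
    return weights @ V_i                             # a (4, 2) output per head

Concatenating the three per-head outputs along the column axis then yields the $4\times 6$ matrix shown in the figure.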

2.3 A Simple Example That Walks Through the Multi-Head Attention Computation

We randomly generate a $4\times 6$ matrix to act as $\hat{X}$:

$$\hat{X}=\begin{bmatrix} 0.22 & 0.87 & 0.21 & 0.92 & 0.49 & 0.61 \\ 0.77 & 0.52 & 0.30 & 0.19 & 0.08 & 0.74 \\ 0.44 & 0.16 & 0.88 & 0.27 & 0.41 & 0.30 \\ 0.63 & 0.58 & 0.60 & 0.27 & 0.28 & 0.25 \end{bmatrix}$$
We also randomly generate the three linear transformation matrices $W^Q, W^K, W^V$:

$$W^Q=\begin{bmatrix} 0.33 & 0.14 & 0.17 & 0.96 & 0.96 & 0.19 \\ 0.02 & 0.20 & 0.70 & 0.78 & 0.02 & 0.58 \\ 0 & 0.52 & 0.64 & 0.99 & 0.26 & 0.80 \\ 0.87 & 0.92 & 0 & 0.47 & 0.98 & 0.40 \\ 0.81 & 0.55 & 0.77 & 0.48 & 0.03 & 0.09 \\ 0.11 & 0.25 & 0.96 & 0.63 & 0.82 & 0.57 \end{bmatrix}$$

$$W^K=\begin{bmatrix} 0.64 & 0.81 & 0.93 & 0.91 & 0.82 & 0.09 \\ 0.36 & 0.04 & 0.55 & 0.80 & 0.05 & 0.19 \\ 0.37 & 0.24 & 0.80 & 0.35 & 0.64 & 0.49 \\ 0.58 & 0.94 & 0.94 & 0.11 & 0.84 & 0.35 \\ 0.10 & 0.38 & 0.51 & 0.96 & 0.37 & 0.01 \\ 0.86 & 0.11 & 0.48 & 0.85 & 0.51 & 0.45 \end{bmatrix}$$

$$W^V=\begin{bmatrix} 0.80 & 0.02 & 0.57 & 0.41 & 0.99 & 0.80 \\ 0.05 & 0.19 & 0.45 & 0.70 & 0.33 & 0.36 \\ 0.92 & 0.95 & 0.41 & 0.90 & 0.33 & 0.08 \\ 0.53 & 0.66 & 0.89 & 0.97 & 0.77 & 0.76 \\ 0.71 & 0.70 & 0.77 & 0.97 & 0.37 & 0.08 \\ 0.24 & 0.22 & 0.36 & 0.81 & 0.06 & 0.45 \end{bmatrix}$$

According to formulas (1), (2) and (3), we obtain $Q$, $K$, $V$:

$$Q=\hat{X}W^Q=\begin{bmatrix} 1.3544 & 1.5824 & 1.7437 & 2.1496 & 1.6997 & 1.4742 \\ 0.5760 & 0.7716 & 1.4589 & 2.0357 & 1.6230 & 1.1929 \\ 0.7484 & 1.1001 & 1.3537 & 1.9311 & 1.1773 & 1.1963 \\ 0.7087 & 0.9811 & 1.3527 & 2.0700 & 1.2504 & 1.2118 \end{bmatrix}$$

$$K=\hat{X}W^K=\begin{bmatrix} 1.6389 & 1.3815 & 2.2586 & 2.0598 & 1.6235 & 0.8894 \\ 1.5456 & 1.0069 & 1.8167 & 1.9484 & 1.4160 & 0.7154 \\ 1.1204 & 1.0166 & 1.8081 & 1.5147 & 1.4635 & 0.7348 \\ 1.2336 & 1.0652 & 1.9015 & 1.7583 & 1.3875 & 0.6707 \end{bmatrix}$$

$$V=\hat{X}W^V=\begin{bmatrix} 1.3946 & 1.4536 & 2.0187 & 2.7500 & 1.5005 & 1.5189 \\ 1.2531 & 0.7434 & 1.2930 & 1.8110 & 1.2532 & 1.3110 \\ 1.6758 & 1.4064 & 1.3476 & 1.9870 & 1.1564 & 0.8530 \\ 1.4869 & 1.1220 & 1.4120 & 1.9403 & 1.3396 & 1.1009 \end{bmatrix}$$

We then split $Q$, $K$, $V$ into $\{Q_1,K_1,V_1\}$, $\{Q_2,K_2,V_2\}$, $\{Q_3,K_3,V_3\}$:

$$\{Q_1,K_1,V_1\}=\left\{ \begin{bmatrix} 1.3544 & 1.5824 \\ 0.5760 & 0.7716 \\ 0.7484 & 1.1001 \\ 0.7087 & 0.9811 \end{bmatrix}, \begin{bmatrix} 1.6389 & 1.3815 \\ 1.5456 & 1.0069 \\ 1.1204 & 1.0166 \\ 1.2336 & 1.0652 \end{bmatrix}, \begin{bmatrix} 1.3946 & 1.4536 \\ 1.2531 & 0.7434 \\ 1.6758 & 1.4064 \\ 1.4869 & 1.1220 \end{bmatrix} \right\}$$

$$\{Q_2,K_2,V_2\}=\left\{ \begin{bmatrix} 1.7437 & 2.1496 \\ 1.4589 & 2.0357 \\ 1.3537 & 1.9311 \\ 1.3527 & 2.0700 \end{bmatrix}, \begin{bmatrix} 2.2586 & 2.0598 \\ 1.8167 & 1.9484 \\ 1.8081 & 1.5147 \\ 1.9015 & 1.7583 \end{bmatrix}, \begin{bmatrix} 2.0187 & 2.7500 \\ 1.2930 & 1.8110 \\ 1.3476 & 1.9870 \\ 1.4120 & 1.9403 \end{bmatrix} \right\}$$

$$\{Q_3,K_3,V_3\}=\left\{ \begin{bmatrix} 1.6997 & 1.4742 \\ 1.6230 & 1.1929 \\ 1.1773 & 1.1963 \\ 1.2504 & 1.2118 \end{bmatrix}, \begin{bmatrix} 1.6235 & 0.8894 \\ 1.4160 & 0.7154 \\ 1.4635 & 0.7348 \\ 1.3875 & 0.6707 \end{bmatrix}, \begin{bmatrix} 1.5005 & 1.5189 \\ 1.2532 & 1.3110 \\ 1.1564 & 0.8530 \\ 1.3396 & 1.1009 \end{bmatrix} \right\}$$

For $\{Q_1,K_1,V_1\}$, we first compute

$$\frac{Q_1K_1^T}{\sqrt{d_k}}=\frac{1}{\sqrt{2}}\begin{bmatrix} 1.3544 & 1.5824 \\ 0.5760 & 0.7716 \\ 0.7484 & 1.1001 \\ 0.7087 & 0.9811 \end{bmatrix}\begin{bmatrix} 1.6389 & 1.5456 & 1.1204 & 1.2336 \\ 1.3815 & 1.0069 & 1.0166 & 1.0652 \end{bmatrix}=\begin{bmatrix} 3.1154 & 2.6069 & 2.2105 & 2.3733 \\ 1.4213 & 1.1789 & 1.0110 & 1.0836 \\ 1.9420 & 1.6012 & 1.3837 & 1.4814 \\ 1.7797 & 1.4731 & 1.2667 & 1.3572 \end{bmatrix}$$

Applying softmax to each row of this result gives

$$\mathrm{softmax}\left(\frac{Q_1K_1^T}{\sqrt{d_k}}\right)=\begin{bmatrix} 0.4029 & 0.2423 & 0.1630 & 0.1918 \\ 0.3163 & 0.2482 & 0.2098 & 0.2257 \\ 0.3431 & 0.2440 & 0.1963 & 0.2165 \\ 0.3344 & 0.2461 & 0.2002 & 0.2192 \end{bmatrix}$$

Finally, multiplying this matrix by $V_1$ gives the output of this self-attention head:

$$\mathrm{softmax}\left(\frac{Q_1K_1^T}{\sqrt{d_k}}\right)V_1=\begin{bmatrix} 0.4029 & 0.2423 & 0.1630 & 0.1918 \\ 0.3163 & 0.2482 & 0.2098 & 0.2257 \\ 0.3431 & 0.2440 & 0.1963 & 0.2165 \\ 0.3344 & 0.2461 & 0.2002 & 0.2192 \end{bmatrix}\begin{bmatrix} 1.3946 & 1.4536 \\ 1.2531 & 0.7434 \\ 1.6758 & 1.4064 \\ 1.4869 & 1.1220 \end{bmatrix}=\begin{bmatrix} 1.4239 & 1.2102 \\ 1.4393 & 1.1926 \\ 1.4353 & 1.1992 \\ 1.4363 & 1.1967 \end{bmatrix}$$

Similarly, we can compute

$$\mathrm{softmax}\left(\frac{Q_2K_2^T}{\sqrt{d_k}}\right)V_2=\begin{bmatrix} 1.6599 & 2.2933 \\ 1.6423 & 2.2714 \\ 1.6340 & 2.2611 \\ 1.6382 & 2.2661 \end{bmatrix}$$

$$\mathrm{softmax}\left(\frac{Q_3K_3^T}{\sqrt{d_k}}\right)V_3=\begin{bmatrix} 1.3315 & 1.2298 \\ 1.3292 & 1.2256 \\ 1.3262 & 1.2206 \\ 1.3268 & 1.2216 \end{bmatrix}$$

Concatenating the three self-attention outputs gives the following matrix:

$$Z=[\mathrm{Attention}_1,\ \mathrm{Attention}_2,\ \mathrm{Attention}_3]=\begin{bmatrix} 1.4239 & 1.2102 & 1.6599 & 2.2933 & 1.3315 & 1.2298 \\ 1.4393 & 1.1926 & 1.6423 & 2.2714 & 1.3292 & 1.2256 \\ 1.4353 & 1.1992 & 1.6340 & 2.2611 & 1.3262 & 1.2206 \\ 1.4363 & 1.1967 & 1.6382 & 2.2661 & 1.3268 & 1.2216 \end{bmatrix}$$
Applying one more linear transformation to $Z$ gives the final output. Here we generate a random $6\times 6$ matrix $W^O$ for this linear transformation:

$$W^O=\begin{bmatrix} 0.81 & 0.26 & 0.06 & 0.24 & 0.09 & 0.81 \\ 0.17 & 0.20 & 0.81 & 0.81 & 0.59 & 0.91 \\ 0.06 & 0.96 & 0.57 & 0.30 & 0.83 & 0.66 \\ 0.99 & 0.11 & 0.58 & 0.47 & 0.65 & 0.24 \\ 0.03 & 0.54 & 0.36 & 0.89 & 0.46 & 0.42 \\ 0.63 & 0.53 & 0.96 & 0.79 & 0.50 & 0.21 \end{bmatrix}$$

The final output of this multi-head attention is then

$$\mathrm{output}=ZW^O=\begin{bmatrix} 4.5438 & 3.8288 & 5.0019 & 5.0544 & 4.9380 & 4.7180 \\ 4.5279 & 3.8066 & 4.9610 & 5.0229 & 4.8970 & 4.6958 \\ 4.5117 & 3.7934 & 4.9495 & 5.0133 & 4.8830 & 4.6883 \\ 4.5179 & 3.7986 & 4.9539 & 5.0164 & 4.8890 & 4.6912 \end{bmatrix}$$
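For comparison, the whole walk-through above should collapse into a single call of the multi_head_attention sketch from Section 1 (a usage sketch, assuming that function is in scope and using the x_hat, W_Q, W_K, W_V, W_O defined in the code of Section 3 below), reproducing the output matrix shown here:

output = multi_head_attention(x_hat, W_Q, W_K, W_V, W_O, h=3)
print(np.round(output, 4))  # should match the output matrix above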

3. Code for the Computation Flow

import numpy as np

# Randomly generate a 4 x 6 word-embedding matrix x_hat
np.random.seed(5)
x_hat = np.random.rand(4, 6)
x_hat = np.round(x_hat, 2)  # keep two decimal places
print(f"Input embedding matrix x_hat:\n{x_hat}")
# Input embedding matrix x_hat:
# [[0.22 0.87 0.21 0.92 0.49 0.61]
#  [0.77 0.52 0.3  0.19 0.08 0.74]
#  [0.44 0.16 0.88 0.27 0.41 0.3 ]
#  [0.63 0.58 0.6  0.27 0.28 0.25]]

# Randomly generate a 6 x 6 linear transformation matrix W_Q
W_Q = np.random.rand(6, 6)
W_Q = np.round(W_Q, 2)  # keep two decimal places
print(f"Linear transformation matrix W_Q:\n{W_Q}")
# Linear transformation matrix W_Q:
# [[0.33 0.14 0.17 0.96 0.96 0.19]
#  [0.02 0.2  0.7  0.78 0.02 0.58]
#  [0.   0.52 0.64 0.99 0.26 0.8 ]
#  [0.87 0.92 0.   0.47 0.98 0.4 ]
#  [0.81 0.55 0.77 0.48 0.03 0.09]
#  [0.11 0.25 0.96 0.63 0.82 0.57]]

# Randomly generate a 6 x 6 linear transformation matrix W_K
W_K = np.random.rand(6, 6)
W_K = np.round(W_K, 2)  # keep two decimal places
print(f"Linear transformation matrix W_K:\n{W_K}")
# Linear transformation matrix W_K:
# [[0.64 0.81 0.93 0.91 0.82 0.09]
#  [0.36 0.04 0.55 0.8  0.05 0.19]
#  [0.37 0.24 0.8  0.35 0.64 0.49]
#  [0.58 0.94 0.94 0.11 0.84 0.35]
#  [0.1  0.38 0.51 0.96 0.37 0.01]
#  [0.86 0.11 0.48 0.85 0.51 0.45]]

# Randomly generate a 6 x 6 linear transformation matrix W_V
W_V = np.random.rand(6, 6)
W_V = np.round(W_V, 2)  # keep two decimal places
print(f"Linear transformation matrix W_V:\n{W_V}")
# Linear transformation matrix W_V:
# [[0.8  0.02 0.57 0.41 0.99 0.8 ]
#  [0.05 0.19 0.45 0.7  0.33 0.36]
#  [0.92 0.95 0.41 0.9  0.33 0.08]
#  [0.53 0.66 0.89 0.97 0.77 0.76]
#  [0.71 0.7  0.77 0.97 0.37 0.08]
#  [0.24 0.22 0.36 0.81 0.06 0.45]]

Q = x_hat @ W_Q
print(f"Q为:\n{Q}")
# Q为:
# [[1.3544 1.5824 1.7437 2.1496 1.6997 1.4742]
#  [0.576  0.7716 1.4589 2.0357 1.623  1.1929]
#  [0.7484 1.1001 1.3537 1.9311 1.1773 1.1963]
#  [0.7087 0.9811 1.3527 2.07   1.2504 1.2118]]

K = x_hat @ W_K
print(f"K为:\n{K}")
# K为:
# [[1.6389 1.3815 2.2586 2.0598 1.6235 0.8894]
#  [1.5456 1.0069 1.8167 1.9484 1.416  0.7154]
#  [1.1204 1.0166 1.8081 1.5147 1.4635 0.7348]
#  [1.2336 1.0652 1.9015 1.7583 1.3875 0.6707]]

V = x_hat @ W_V
print(f"V为:\n{V}")
# V为:
# [[1.3946 1.4536 2.0187 2.75   1.5005 1.5189]
#  [1.2531 0.7434 1.293  1.811  1.2532 1.311 ]
#  [1.6758 1.4064 1.3476 1.987  1.1564 0.853 ]
#  [1.4869 1.122  1.412  1.9403 1.3396 1.1009]]

Q_1, K_1, V_1 = Q[:, 0:2], K[:, 0:2], V[:, 0:2]
Q_2, K_2, V_2 = Q[:, 2:4], K[:, 2:4], V[:, 2:4]
Q_3, K_3, V_3 = Q[:, 4:6], K[:, 4:6], V[:, 4:6]


# A common issue when applying softmax is numerical stability: because of the exponentials,
# sum(e^x) can overflow. This is avoided by subtracting the row-wise maximum from every value first.
def softmax(x):
    row_max = np.max(x, axis=1, keepdims=True)  # max of each row, keeping dims (e.g. 4 x 5 -> 4 x 1 instead of a 1-D array)
    e_x = np.exp(x - row_max)                   # subtract the corresponding row maximum from every element in that row
    row_sum = np.sum(e_x, axis=1, keepdims=True)
    out = e_x / row_sum
    return out


Q_KT_d_k_1 = Q_1 @ K_1.T / np.sqrt(2)
print(f"Q_KT_d_k_1为: \n{Q_KT_d_k_1}")
# Q_KT_d_k_1为:
# [[3.11537937 2.60687586 2.2105131  2.37330514]
#  [1.42126469 1.1788811  1.01099226 1.08361422]
#  [1.94195628 1.60118513 1.38371535 1.48142601]
#  [1.77970156 1.47307052 1.2667208  1.35716422]]
soft_Q_KT_1 = softmax(Q_KT_d_k_1)
print(f"Softmax result is \n{soft_Q_KT_1}")
# Softmax result is
# [[0.40288203 0.24229119 0.16300445 0.19182233]
#  [0.31628863 0.24820911 0.20984785 0.22565442]
#  [0.34312552 0.2440383  0.19634148 0.2164947 ]
#  [0.3344468  0.24612678 0.20023608 0.21919034]]

out_self_attention_1 = soft_Q_KT_1 @ V_1
print(f"Self attention output 1 is \n{out_self_attention_1}")
# Self attention output 1 is
# [[1.42385785 1.2102227 ]
#  [1.43931553 1.19259007]
#  [1.43526227 1.19922704]
#  [1.43631071 1.1966661 ]]

out_self_attention_2 = softmax(Q_2 @ K_2.T / np.sqrt(2)) @ V_2
print(f"Self attention output 2 is \n{out_self_attention_2}")
# Self attention output 2 is
# [[1.65989199 2.29334469]
#  [1.6423284  2.27141789]
#  [1.63397616 2.26112136]
#  [1.63815253 2.2660779 ]]
out_self_attention_3 = softmax(Q_3 @ K_3.T / np.sqrt(2)) @ V_3
print(f"Self attention output 3 is \n{out_self_attention_3}")
# Self attention output 3 is
# [[1.33149842 1.22979722]
#  [1.32918253 1.2256465 ]
#  [1.32621018 1.22056725]
#  [1.32678984 1.22156985]]
concat_123 = np.concatenate((out_self_attention_1, out_self_attention_2, out_self_attention_3), axis=1)
print(f"Concat attention output is \n{concat_123}")
# Concat attention output is
# [[1.42385785 1.2102227  1.65989199 2.29334469 1.33149842 1.22979722]
#  [1.43931553 1.19259007 1.6423284  2.27141789 1.32918253 1.2256465 ]
#  [1.43526227 1.19922704 1.63397616 2.26112136 1.32621018 1.22056725]
#  [1.43631071 1.1966661  1.63815253 2.2660779  1.32678984 1.22156985]]

W_O = np.random.rand(6, 6)
W_O = np.round(W_O, 2)  # keep two decimal places
print(f"Linear transformation matrix W_O:\n{W_O}")
# Linear transformation matrix W_O:
# [[0.81 0.26 0.06 0.24 0.09 0.81]
#  [0.17 0.2  0.81 0.81 0.59 0.91]
#  [0.06 0.96 0.57 0.3  0.83 0.66]
#  [0.99 0.11 0.58 0.47 0.65 0.24]
#  [0.03 0.54 0.36 0.89 0.46 0.42]
#  [0.63 0.53 0.96 0.79 0.5  0.21]]

output = concat_123 @ W_O
print(f"output 为:\n{output}")
# output 为:
# [[4.54378468 3.82881348 5.00193498 5.05441927 4.93795088 4.71804571]
#  [4.52786208 3.8065825  4.9610328  5.0229318  4.89696796 4.69582201]
#  [4.51172342 3.7934082  4.94948667 5.01333192 4.88298697 4.68827983]
#  [4.51794388 3.79856753 4.9539017  5.01639962 4.88902644 4.69119859]]