import torch
import torch.nn as nn
import numpy as np


# Reproducing nn.MultiheadAttention by hand
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True


# Set the random seed
setup_seed(20)

# Unbatched 2-D inputs: (sequence length 1, embed_dim 1)
Q = torch.tensor([[1]], dtype=torch.float32)
K = torch.tensor([[3]], dtype=torch.float32)
V = torch.tensor([[5]], dtype=torch.float32)

multiHead = nn.MultiheadAttention(1, 1)  # embed_dim=1, num_heads=1
att_o, att_o_w = multiHead(Q, K, V)

################################

# Reproduce Multi-head Attention manually
w = multiHead.in_proj_weight     # (3 * embed_dim, embed_dim): stacked Q/K/V projections
b = multiHead.in_proj_bias
w_o = multiHead.out_proj.weight
b_o = multiHead.out_proj.bias

w_q, w_k, w_v = w.chunk(3)       # split into the Q, K and V projection weights
b_q, b_k, b_v = b.chunk(3)

# Project Q, K and V (PyTorch's linear layers compute x @ W.T + b, hence the transposes)
q = Q @ w_q.T + b_q
k = K @ w_k.T + b_k
v = V @ w_v.T + b_v
dk = q.shape[-1]
# Attention weights: softmax(q @ k^T / sqrt(d_k))
softmax_2 = torch.nn.Softmax(dim=-1)
att_o_w2 = softmax_2(q @ k.transpose(-2, -1) / np.sqrt(dk))
# Weighted sum of the values (use the recomputed weights att_o_w2, not the module output)
out = att_o_w2 @ v
# Output projection
att_o2 = out @ w_o.T + b_o
print(att_o, att_o_w)
print(att_o2, att_o_w2)
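
In symbols, the manual block computes scaled dot-product attention followed by the output projection, where d_k is the key dimension per head:

q = Q W_Q^\top + b_Q, \qquad k = K W_K^\top + b_K, \qquad v = V W_V^\top + b_V

\mathrm{output} = \operatorname{softmax}\!\left(\frac{q k^\top}{\sqrt{d_k}}\right) v \, W_O^\top + b_O
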
Output:
tensor([[-0.4038]], grad_fn=<SqueezeBackward1>) tensor([[1.]], grad_fn=<SqueezeBackward1>)
tensor([[-0.4038]], grad_fn=<AddBackward0>) tensor([[1.]], grad_fn=<SoftmaxBackward0>)
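
The same recipe generalizes beyond this 1x1 toy case. Below is a minimal sketch that repeats the reproduction for batched 3-D inputs with embed_dim=4 and num_heads=2, assuming the module's defaults (no masks, dropout=0.0, batch_first=False, attention weights averaged over heads); the dimension values E, H, L, S, N are arbitrary illustration choices.

import math
import torch
import torch.nn as nn

torch.manual_seed(20)

E, H, L, S, N = 4, 2, 3, 5, 2          # embed_dim, heads, tgt len, src len, batch
head_dim = E // H

Q = torch.randn(L, N, E)
K = torch.randn(S, N, E)
V = torch.randn(S, N, E)

mha = nn.MultiheadAttention(E, H)
ref_out, ref_w = mha(Q, K, V)          # ref_w is averaged over heads by default

w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

# Project, then split the embedding dim into heads: (L, N, E) -> (N*H, L, head_dim)
q = (Q @ w_q.T + b_q).reshape(L, N * H, head_dim).transpose(0, 1)
k = (K @ w_k.T + b_k).reshape(S, N * H, head_dim).transpose(0, 1)
v = (V @ w_v.T + b_v).reshape(S, N * H, head_dim).transpose(0, 1)

attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
out = attn @ v                                           # (N*H, L, head_dim)
out = out.transpose(0, 1).reshape(L, N, E)               # concatenate the heads
out = out @ mha.out_proj.weight.T + mha.out_proj.bias    # output projection

print(torch.allclose(out, ref_out, atol=1e-6))                              # expect True
print(torch.allclose(attn.view(N, H, L, S).mean(dim=1), ref_w, atol=1e-6))  # expect True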