
The implementation of torch.nn.MultiheadAttention() in PyTorch (one-dimensional case)

Reproducing torch's built-in nn.MultiheadAttention. The listing below first runs the built-in module on 1x1 inputs and then re-implements the same computation by hand: input projections, scaled dot-product attention weights, weighted sum of the values, and the output projection.
import torch
import torch.nn as nn
import numpy as np

# TODO MHA
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

# Fix the random seed so the initialized weights are reproducible
setup_seed(20)

# Unbatched 1-D inputs: seq_len = 1, embed_dim = 1, so every tensor is 1x1
Q = torch.tensor([[1]], dtype=torch.float32)
K = torch.tensor([[3]], dtype=torch.float32)
V = torch.tensor([[5]], dtype=torch.float32)

# Built-in module: embed_dim = 1, num_heads = 1
multiHead = nn.MultiheadAttention(1, 1)
att_o, att_o_w = multiHead(Q, K, V)

################################
# Reproduce multi-head attention by hand
w = multiHead.in_proj_weight      # packed Q/K/V projection weights, shape (3 * embed_dim, embed_dim)
b = multiHead.in_proj_bias
w_o = multiHead.out_proj.weight
b_o = multiHead.out_proj.bias
w_q, w_k, w_v = w.chunk(3)
b_q, b_k, b_v = b.chunk(3)

# Project Q, K, V.
# nn.Linear computes x @ W.T + b; with 1x1 weights the transpose can be dropped.
q = Q @ w_q + b_q
k = K @ w_k + b_k
v = V @ w_v + b_v
dk = q.shape[-1]

# Attention weights: softmax(q k^T / sqrt(d_k))
softmax_2 = torch.nn.Softmax(dim=-1)
att_o_w2 = softmax_2(q @ k.transpose(-2, -1) / np.sqrt(dk))

# Weighted sum of the values, using the reproduced weights att_o_w2
out = att_o_w2 @ v

# Output projection
att_o2 = out @ w_o + b_o

print(att_o, att_o_w)
print(att_o2, att_o_w2)
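Instead of eyeballing the printed tensors, the match can also be checked numerically. A minimal follow-up, assuming the variables from the listing above are still in scope:

# Continuation of the listing above: compare the module's results with the manual ones
print(torch.allclose(att_o, att_o2))      # expected: True
print(torch.allclose(att_o_w, att_o_w2))  # expected: True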

Output:

tensor([[-0.4038]], grad_fn=<SqueezeBackward1>) tensor([[1.]], grad_fn=<SqueezeBackward1>)
tensor([[-0.4038]], grad_fn=<AddBackward0>) tensor([[1.]], grad_fn=<SoftmaxBackward0>)
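Both weight tensors are [[1.]] because there is only a single key: the softmax over one score is trivially 1, so the 1-D example mainly verifies the projection chain. The simplifications above (no transpose on the projection weights, no head splitting) only hold because every weight matrix is 1x1. Below is a sketch of the same reproduction for the general batched, multi-head case, written against the module's default (seq_len, batch, embed_dim) layout and checked with torch.allclose; the sizes (embed_dim=4, num_heads=2, and so on) are illustrative choices, not from the original post.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes
embed_dim, num_heads = 4, 2
head_dim = embed_dim // num_heads
L, S, N = 3, 5, 2                                    # target length, source length, batch size

mha = nn.MultiheadAttention(embed_dim, num_heads)    # batch_first=False by default
Q = torch.randn(L, N, embed_dim)
K = torch.randn(S, N, embed_dim)
V = torch.randn(S, N, embed_dim)
ref_out, ref_w = mha(Q, K, V)

# Unpack the packed input projection, as in the 1-D example
w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

# Input projections; F.linear computes x @ W.T + b, matching the module
q = F.linear(Q, w_q, b_q)   # (L, N, E)
k = F.linear(K, w_k, b_k)   # (S, N, E)
v = F.linear(V, w_v, b_v)   # (S, N, E)

# Split the embedding into heads: (N * num_heads, seq_len, head_dim)
def split_heads(x, seq_len):
    return x.contiguous().view(seq_len, N * num_heads, head_dim).transpose(0, 1)

q, k, v = split_heads(q, L), split_heads(k, S), split_heads(v, S)

# Scaled dot-product attention per head
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)   # (N*H, L, S)
out = attn @ v                                                            # (N*H, L, head_dim)

# Merge the heads back to (L, N, E) and apply the output projection
out = out.transpose(0, 1).contiguous().view(L, N, embed_dim)
out = F.linear(out, mha.out_proj.weight, mha.out_proj.bias)

# The module returns attention weights averaged over the heads by default
attn_avg = attn.view(N, num_heads, L, S).mean(dim=1)

print(torch.allclose(out, ref_out, atol=1e-6))      # expected: True
print(torch.allclose(attn_avg, ref_w, atol=1e-6))   # expected: True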
