
Transformer: Multi-Head Attention Mechanism (PyTorch)

1. Schematic Diagram
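The underlying principle is the standard scaled dot-product attention and its multi-head extension from "Attention Is All You Need". In the notation of the code below, $d_k$ = head_dim = embed_size / heads = 64 and $h$ = heads = 8:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V})$$

The code realizes the per-head projections $W_i^{Q}, W_i^{K}, W_i^{V}$ with single embed_size × embed_size linear layers whose outputs are reshaped into heads, which is equivalent to stacking the per-head projection matrices.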

2. Code

The implementation below builds multi-head self-attention from scratch in PyTorch and then runs it on a random batch to check the output shape.

import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, queries, keys, values, mask):
        N = queries.shape[0]          # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]       # sequence_length
        value_len = values.shape[1]   # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces:
        # batch_size, sequence_length, embed_size(512) -->
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1 / 2))
        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))

        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)
        out = torch.matmul(attention, values)

        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # so the output can be fed into the layers that follow
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)

        out = self.fc_out(out)
        return out


batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)
K = torch.randn(batch_size, sequence_length, embed_size)
V = torch.randn(batch_size, sequence_length, embed_size)

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)  # torch.Size([64, 10, 512])
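The demo above passes mask = None, so the masking branch is never exercised. Because score has shape (batch_size, heads, query_len, key_len) and masked_fill broadcasts its mask, a padding mask of shape (batch_size, 1, 1, key_len) is sufficient. Below is a minimal sketch; the padding pattern (last two positions of every sequence) is made up purely for illustration and is not part of the original article:

# Hypothetical padding mask: 1 = keep, 0 = mask out.
# Shape (batch_size, 1, 1, key_len) broadcasts against score,
# whose shape is (batch_size, heads, query_len, key_len).
pad_mask = torch.ones(batch_size, 1, 1, sequence_length)
pad_mask[:, :, :, -2:] = 0  # pretend the last two tokens of every sequence are padding

masked_output = model(Q, K, V, pad_mask)
print(masked_output.shape)  # still torch.Size([64, 10, 512])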

 
