
[Deep Learning] Multi-Head Attention: Principles and Code Implementation

1. Multi-Head Attention Structure

2. Multi-Head Attention Computation

        Multi-Head Attention is built from multiple Self-Attention modules. For the same text, a single attention head yields one representation space; with several attention heads, several different representation spaces can be obtained. This is the idea behind Multi-Head Attention: it provides attention with multiple "representation subspaces". Each head uses its own Query / Key / Value weight matrices, which are randomly initialized; through training, the word embeddings are projected into these different representation subspaces.
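        For concreteness, this can be summarized with the standard formulation from the Transformer paper (the notation below is the usual one, added here as a sketch; h is the number of heads and d_k the per-head key dimension):

        \mathrm{head}_i = \mathrm{Attention}(X W_i^Q,\, X W_i^K,\, X W_i^V), \qquad \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

        \mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O

        Note that in the code in Section 3, each set of per-head projections is realized by a single nn.Linear followed by a reshape into heads, and the final output projection W^O is omitted (the head outputs are only concatenated).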

        Splitting the model into multiple heads, i.e., multiple subspaces, lets the model attend to different aspects of the information. Multi-Head Attention simply runs the Scaled Dot-Product Attention process 8 times (once per head) and then concatenates the outputs, as sketched below.
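        A minimal sketch of the Scaled Dot-Product Attention step for a single head (the function name and shapes here are illustrative, not part of the article's code):

from math import sqrt
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k: (batch, n, dk); v: (batch, n, dv)
    dk = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(dk)  # (batch, n, n)
    weights = torch.softmax(scores, dim=-1)                   # each row sums to 1
    return torch.matmul(weights, v)                           # (batch, n, dv)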

3. Multi-Head Attention Code Implementation

from math import sqrt

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    dim_in: int     # input dimension
    dim_k: int      # key and query dimension
    dim_v: int      # value dimension
    num_heads: int  # number of heads; for each head, dim_* = dim_* // num_heads

    def __init__(self, dim_in, dim_k, dim_v, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        # dimensions must be divisible by num_heads
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, \
            "dim_k and dim_v must be multiples of num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        # linear projection matrices for Q, K, V
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k // num_heads)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)

        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)                           # (batch, nh, n, n)

        att = torch.matmul(dist, v)                                  # (batch, nh, n, dv)
        att = att.transpose(1, 2).reshape(batch, n, self.dim_v)      # (batch, n, dim_v)
        return att
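
        A quick usage sketch, assuming the class above is in scope (the dimensions are arbitrary example values):

x = torch.rand(2, 10, 64)  # (batch=2, n=10, dim_in=64)
mhsa = MultiHeadSelfAttention(dim_in=64, dim_k=128, dim_v=128, num_heads=8)
att = mhsa(x)
print(att.shape)  # torch.Size([2, 10, 128]), i.e. (batch, n, dim_v)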