Multi-Head Attention is built by combining several Self-Attention modules. For the same piece of text, a single attention head yields one representation space; with multiple heads, the model obtains several different representation spaces. This is the idea behind Multi-Head Attention: it gives attention multiple "representation subspaces." Each head uses its own Query / Key / Value weight matrices, each randomly initialized, and through training the word embeddings are projected into these different representation subspaces.
Splitting the model into multiple heads, and therefore multiple subspaces, lets the model attend to different aspects of the information. The Multi-Head Attention in the figure above simply runs Scaled Dot-Product Attention 8 times and concatenates the outputs.
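In equation form (the standard formulation from "Attention Is All You Need"), each head applies scaled dot-product attention to its own projections of the input, and the head outputs are concatenated. The implementation below follows this recipe but stops at the concatenation, omitting the final output projection $W^{O}$:

```latex
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(Q W_i^{Q})(K W_i^{K})^{\top}}{\sqrt{d_k}}\right) V W_i^{V}

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
```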
```python
from math import sqrt

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    dim_in: int      # input dimension
    dim_k: int       # key and query dimension
    dim_v: int       # value dimension
    num_heads: int   # number of heads; per-head dim_* = dim_* // num_heads

    def __init__(self, dim_in, dim_k, dim_v, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        # dim_k and dim_v must be divisible by num_heads
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, \
            "dim_k and dim_v must be multiples of num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        # linear projection matrices for Q, K, V (no bias)
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        # scaling factor 1 / sqrt(d_k), using the per-head key dimension
        self._norm_fact = 1 / sqrt(dim_k // num_heads)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head

        # project, then split the last dimension into heads
        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)

        # scaled dot-product attention scores per head
        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)                           # (batch, nh, n, n)

        # weighted sum of values, then concatenate the heads
        att = torch.matmul(dist, v)                                  # (batch, nh, n, dv)
        att = att.transpose(1, 2).reshape(batch, n, self.dim_v)      # (batch, n, dim_v)
        return att
```
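A minimal usage sketch follows; the batch size, sequence length, and dimensions here are illustrative assumptions, not values from the original post:

```python
# Smoke test: batch of 2 sequences, 10 tokens each, embedding size 512,
# projected to dim_k = dim_v = 512 and split across 8 heads (64 dims per head).
model = MultiHeadSelfAttention(dim_in=512, dim_k=512, dim_v=512, num_heads=8)
x = torch.rand(2, 10, 512)   # (batch, n, dim_in)
out = model(x)
print(out.shape)             # torch.Size([2, 10, 512]) -> (batch, n, dim_v)
```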
