import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        # Dimensionality of the input features
        self.d_model = d_model
        # Dimensionality of each head
        self.d_k = d_model // heads
        # Number of heads
        self.h = heads
        # Q, K, V are obtained by projecting the inputs through three linear layers
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Output linear layer
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, mask=None):
        # Batched matrix product of the queries with the transposed keys gives the
        # scores used to weight the values; scale by sqrt(d_k) so the softmax does
        # not saturate and its gradients do not vanish
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Softmax over the key dimension, then dropout on the attention weights
        score = F.softmax(scores, dim=-1)
        score = self.dropout(score)
        output = torch.matmul(score, v)
        return output

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        # Reshape each projection to a (batch, heads, seq_len, d_k) tensor
        q = self.q_linear(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        score = self.attention(q, k, v, mask)
        # Merge the heads back into a (batch, seq_len, d_model) tensor
        concat = score.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        output = self.out(concat)
        return output


if __name__ == '__main__':
    heads = 4
    d_model = 128
    dropout = 0.1
    model = MultiHeadAttention(heads, d_model, dropout)
    batch_size = 2
    seq_len = 5
    q = torch.rand(batch_size, seq_len, d_model)
    k = torch.rand(batch_size, seq_len, d_model)
    v = torch.rand(batch_size, seq_len, d_model)
    output = model(q, k, v)
    # Dummy scalar loss to check that gradients flow end to end
    loss = output.mean()
    loss.backward()
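For decoder-style self-attention, a causal mask can be passed through the mask argument: positions where mask == 0 are set to -inf before the softmax, so each query can only attend to earlier positions. A minimal sketch, reusing the model, q, k, v, and seq_len from the __main__ block above; the (seq_len, seq_len) mask broadcasts over the batch and head dimensions of the scores:

    # Lower-triangular mask: 1 on/below the diagonal, 0 above (future positions)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))
    # Entries with mask == 0 are filled with -inf inside attention()
    masked_output = model(q, k, v, mask=causal_mask)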
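As a sanity check, PyTorch ships an equivalent built-in layer, nn.MultiheadAttention. The sketch below (assuming PyTorch >= 1.9 for the batch_first=True option, and the variables from the __main__ block above) only confirms that the input and output shapes line up; the numerical outputs will differ because the two modules initialize their projection weights independently:

    builtin = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads,
                                    dropout=dropout, batch_first=True)
    attn_out, attn_weights = builtin(q, k, v)
    print(attn_out.shape)      # torch.Size([2, 5, 128]), same as output.shape
    print(attn_weights.shape)  # torch.Size([2, 5, 5]), weights averaged over heads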