The figure below shows the overall architecture of the Transformer, which consists of an encoder and a decoder. The following sections analyze the Transformer's structure part by part (the figure is taken from the paper "Attention Is All You Need").
(Paper link: https://arxiv.org/pdf/1706.03762.pdf)
The encoder is mainly used to encode the input sentence. Below, the sentence "我爱学习" ("I love studying") is used as an example to walk through the encoding process.
First, the input is given as word vectors, and a positional encoding is added to each word vector, i.e., position information that marks the positions of the four word vectors for '我', '爱', '学', and '习', as sketched below.
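A minimal sketch of the sinusoidal positional encoding used in the paper (the function name and the max_len/d_model values here are illustrative choices, not taken from the original post):

import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)      # (max_len, 1)
    i = torch.arange(0, d_model, 2, dtype=torch.float32)               # (d_model/2,)
    angle = pos / torch.pow(torch.tensor(10000.0), i / d_model)        # (max_len, d_model/2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angle)
    pe[:, 1::2] = torch.cos(angle)
    return pe                                                          # (max_len, d_model)

# The four word vectors for '我' '爱' '学' '习' each get the row of pe at their position added
# (the embeddings below are random placeholders, just to show the shapes involved):
word_vectors = torch.randn(4, 512)
encoder_input = word_vectors + sinusoidal_positional_encoding(4, 512)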
The next step is to feed the word vectors, now fused with position information, into self-attention for encoding.
"Masked" means that when producing the n-th encoding, only the n-th position and the positions before it may be considered; information from later positions must not be used. A sketch of how such a mask can be applied follows.
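A minimal sketch of applying a causal mask inside scaled dot-product attention (this masking logic is not part of the encoder code below; the function name is illustrative):

import torch

def masked_attention(q, k, v):
    # q, k, v: (batch, seq_len, d_model)
    d_model = q.size(-1)
    seq_len = q.size(1)
    scores = q.bmm(k.transpose(1, 2)) / d_model ** 0.5                 # (batch, seq_len, seq_len)
    # Upper-triangular positions (j > i) correspond to "future" tokens: set them to -inf
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float('-inf'))
    weights = torch.softmax(scores, dim=-1)                            # row i only attends to positions <= i
    return weights.bmm(v)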
Note: only the input connections for V are drawn.
A minimal single-layer encoder (self-attention plus a position-wise feed-forward network, each with a residual connection and layer normalization) can be implemented as follows:

import torch


class Layernorm_m(torch.nn.Module):
    """Layer normalization over the feature dimension, without learnable scale/shift."""
    def __init__(self):
        super(Layernorm_m, self).__init__()

    def forward(self, x):
        # x: (batch, seq_len, d_model); normalize each token's feature vector
        mean = torch.mean(x, dim=2)
        std = torch.std(x, dim=2)
        return (x - mean[:, :, None]) / std[:, :, None]


class Attention(torch.nn.Module):
    """Single-head scaled dot-product self-attention with residual connection and layer norm."""
    def __init__(self):
        super(Attention, self).__init__()
        self.Wq = torch.nn.Linear(512, 512, bias=False)
        self.Wk = torch.nn.Linear(512, 512, bias=False)
        self.Wv = torch.nn.Linear(512, 512, bias=False)
        self.fc = torch.nn.Linear(512, 512, bias=False)
        self.layernorm = Layernorm_m()

    def forward(self, x):
        res = x                      # keep the input for the residual connection
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)
        # softmax(Q @ K^T / sqrt(d_model)) @ V
        A = q.bmm(k.permute(0, 2, 1)) / torch.sqrt(torch.tensor(512, dtype=torch.float32))
        A = torch.softmax(A, dim=-1)
        x = A.bmm(v)
        x = self.fc(x)
        return self.layernorm(x + res)


class PoswiseFeedForwardNet(torch.nn.Module):
    """Position-wise feed-forward network with residual connection and layer norm."""
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = torch.nn.Linear(512, 512)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(512, 512)
        self.layerNorm = Layernorm_m()

    def forward(self, x):
        res = x
        x = self.fc(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.layerNorm(x + res)


class Encoder(torch.nn.Module):
    """One encoder layer: self-attention followed by the feed-forward sub-layer."""
    def __init__(self):
        super(Encoder, self).__init__()
        self.attention = Attention()
        self.ffn = PoswiseFeedForwardNet()

    def forward(self, x):
        x = self.attention(x)
        x = self.ffn(x)
        return x


x = torch.randn((4, 16, 512))   # a batch of 4 sequences, 16 tokens each, d_model = 512
encoder = Encoder()
x = encoder(x)
print(x)
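For comparison with the paper: this sketch uses a single attention head and keeps the feed-forward hidden size at 512, whereas the original Transformer uses 8 attention heads, an inner feed-forward dimension of 2048, and stacks 6 such encoder layers.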