Transformer Code Structure
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import _LRScheduler

# S: symbol that marks the start of the decoder input
# E: symbol that marks the end of the decoder target
# P: padding symbol used when a sequence is shorter than the maximum length
sentences = [
    # enc_input                dec_input             dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    # German source; 'S' starts the decoder input, 'E' ends the decoder target (the label the final output is compared against in the loss)
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
    # second German example
]

# Padding must map to index zero
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)  # 6

tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)  # 9

src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input (= dec_output) max sequence length

# Transformer parameters
d_model = 512   # embedding size
d_ff = 2048     # feed-forward hidden dimension
d_k = d_v = 64  # dimension of K (= Q) and V per head
n_layers = 6    # number of encoder / decoder layers
n_heads = 8     # number of heads in multi-head attention
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]   # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]   # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]  # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]
        # extend appends the elements of one list to the end of another list
        enc_inputs.extend(enc_input)
        dec_inputs.extend(dec_input)
        dec_outputs.extend(dec_output)
    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet, self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)  # batch_size = 2, shuffle = True
enc, dec_in, dec_out = make_data(sentences)
print(enc)
print(enc[0])
print(enc.shape)
print(enc.shape[0])
tensor([[1, 2, 3, 4, 0],
[1, 2, 3, 5, 0]])
tensor([1, 2, 3, 4, 0])
torch.Size([2, 5])
2
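As a quick check of the DataLoader defined above (illustrative only), iterating over loader once shows the batch shapes it yields:
for enc_b, dec_in_b, dec_out_b in loader:
    print(enc_b.shape, dec_in_b.shape, dec_out_b.shape)  # torch.Size([2, 5]) torch.Size([2, 6]) torch.Size([2, 6])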
idx2word
{0: 'P',
1: 'i',
2: 'want',
3: 'a',
4: 'beer',
5: 'coke',
6: 'S',
7: 'E',
8: '.'}
emb = nn.Embedding(src_vocab_size, d_model)
enc_o = emb(enc) # [batch_size, src_len, d_model]
print(enc_o.shape)
print(enc_o)
torch.Size([2, 5, 512])
tensor([[[ 0.2115, 0.1797, 0.3934, ..., -1.3825, 0.4188, -0.5924],
[ 0.8305, -0.1925, -0.4087, ..., -1.2020, 1.5736, 0.8208],
[-2.3016, 0.3060, -0.2981, ..., 1.8724, -0.3179, -0.6690],
[ 0.6748, 0.2370, -1.0590, ..., -0.2914, -0.0615, 0.2832],
[-0.9452, -0.1783, 0.4750, ..., 0.0894, 2.0903, -1.8880]],
[[ 0.2115, 0.1797, 0.3934, ..., -1.3825, 0.4188, -0.5924],
[ 0.8305, -0.1925, -0.4087, ..., -1.2020, 1.5736, 0.8208],
[-2.3016, 0.3060, -0.2981, ..., 1.8724, -0.3179, -0.6690],
[ 0.5548, -0.0924, 0.0213, ..., -1.6186, 1.0486, -0.0319],
[-0.9452, -0.1783, 0.4750, ..., 0.0894, 2.0903, -1.8880]]],
grad_fn=<EmbeddingBackward0>)
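The printout above shows identical rows wherever the two sentences share a token ('ich mochte ein' and the trailing 'P'). A minimal check of that observation, reusing emb, enc and enc_o from the cells above (illustrative, not part of the model):
print(torch.equal(enc_o[0, :3], enc_o[1, :3]))    # True: 'ich mochte ein' is shared by both sentences
print(torch.equal(enc_o[0], emb.weight[enc[0]]))  # True: nn.Embedding is just a row lookup into emb.weight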
def plot_position_embedding(position):
plt.pcolormesh(position[0], cmap = 'RdBu')
plt.xlabel('Depth')
plt.xlim((0, 512))
plt.colorbar()
plt.show()
# enc_o = enc_o.detach().numpy()
# plot_position_embedding(enc_o)
# 3. Positional encoding
# PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # unsqueeze(1) adds a dimension, turning the shape from (max_len,) into (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # torch.arange(0, d_model, 2).float() creates the even indices 0, 2, ..., d_model-2 as floats;
        # (-math.log(10000.0) / d_model) is a scalar; torch.exp() applies the exponential element-wise,
        # which is a numerically convenient way of computing 1 / 10000^(2i/d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # all even columns
        pe[:, 1::2] = torch.cos(position * div_term)  # all odd columns
        pe = pe.unsqueeze(0).transpose(0, 1)  # [5000, 512] --> [1, 5000, 512] --> [5000, 1, 512]
        # pe.unsqueeze(0) prepends a dimension, turning [max_len, d_model] into [1, max_len, d_model];
        # transpose(0, 1) then swaps that new dimension with the sequence dimension, giving [max_len, 1, d_model]
        self.register_buffer('pe', pe)
        # register_buffer adds a persistent tensor to the module: it is saved with the model's state,
        # but it is not a parameter, so the optimizer never updates it

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]  # x.size(0) is the input sequence length
        return self.dropout(x)


# 4. The attention score matrix also contains scores for pad positions. To remove their influence,
#    build a marker matrix with 1 at the pad positions and 0 elsewhere, and fill the marked positions
#    with a very large negative value before the softmax.
def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or tgt_len;
    seq_len in seq_q and seq_len in seq_k may differ
    '''
    # seq_q and seq_k are not necessarily the same: in cross-attention the queries come from the decoder
    # and the keys from the encoder, so only the key-side (encoder) pad positions need to be masked here;
    # the decoder-side pad positions are not masked in the cross-attention layer
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(0) marks the PAD token: it returns a tensor the same size as seq_k with True at pad positions
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True means masked
    # expand broadcasts the mask to [batch_size, len_q, len_k], so every query row sees the same key mask
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]


'''
Both seq2seq and the Transformer "predict" autoregressively: the input of the current step is the output
of the previous step. Training differs, though. A seq2seq model is still fed step by step during training,
so each step depends on the previous output, whereas the Transformer is trained in parallel, feeding the
whole target sequence at once. To keep each position from "seeing" later tokens, we mask them with an
upper-triangular matrix whose upper triangle is 1 and whose diagonal and lower triangle are 0; the
1-positions are later filled with a very large negative value.
'''
def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # np.ones(attn_shape) builds an all-ones matrix of that shape;
    # np.triu(..., k=1) keeps only the part strictly above the main diagonal
    # (k=1 excludes the diagonal itself) and zeroes everything else
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # upper triangular matrix
    # convert the NumPy array to a PyTorch tensor of byte type
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]
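To make the exp/log trick used for div_term concrete, here is a small numerical check (not part of the model; it reuses d_model from above) that it equals the plain 1 / 10000^(2i/d_model) form of the formula:
i2 = torch.arange(0, d_model, 2).float()                         # the even indices 2i
div_term_check = torch.exp(i2 * (-math.log(10000.0) / d_model))  # as in PositionalEncoding
direct = 1.0 / torch.pow(torch.tensor(10000.0), i2 / d_model)    # 10000^(-2i/d_model)
print(torch.allclose(div_term_check, direct))                    # True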
get_attn_subsequence_mask
# Earlier positions must not see later positions; the lower triangle of the matrix (and the diagonal) is all zeros.
# In the mask, positions that should be ignored are set to 1 and positions that should be kept are set to 0.
# During the computation the 1-positions are filled with a very large negative number; its exponential in the softmax is essentially 0, so the corresponding attention weight becomes 0.
x = torch.tensor([[7, 6, 0, 0, 0], [1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
get_attn_subsequence_mask(x)
tensor([[[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]], [[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]], [[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]]], dtype=torch.uint8)
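To see concretely why the masked (1/True) positions end up with zero attention weight, here is a tiny standalone example of the masked_fill + softmax step; the numbers are made up for illustration:
scores = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
mask = torch.tensor([[False, False, True, True]])  # True = position to be ignored
masked = scores.masked_fill(mask, -1e9)            # fill masked slots with a very large negative number
print(torch.softmax(masked, dim=-1))               # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])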
PositionalEncoding
position_embedding = PositionalEncoding(d_model)
input_tensor = torch.zeros(1, 50, 512)
position = position_embedding.forward(input_tensor)
print(position)
print(position.shape)
tensor([[[0.0000, 1.1111, 0.0000, ..., 1.1111, 0.0000, 1.1111],
[0.0000, 1.1111, 0.0000, ..., 1.1111, 0.0000, 1.1111],
[0.0000, 0.0000, 0.0000, ..., 1.1111, 0.0000, 1.1111],
...,
[0.0000, 1.1111, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 1.1111, 0.0000, ..., 1.1111, 0.0000, 1.1111],
[0.0000, 1.1111, 0.0000, ..., 1.1111, 0.0000, 1.1111]]])
torch.Size([1, 50, 512])
# Note: the x and y axes in this plot are swapped compared with the figure shown in the theory section
def plot_position_embedding(position):
plt.pcolormesh(position[0], cmap = 'RdBu')
plt.xlabel('Depth')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()
plot_position_embedding(position)
pe1 = torch.zeros(5000, 512)
position1 = torch.arange(0, 5000, dtype=torch.float).unsqueeze(1) # unsqueeze(1) adds a dimension, turning the shape from (5000,) into (5000, 1)
div_term1 = torch.exp(torch.arange(0, 512, 2).float() * (-math.log(10000.0) / 512))
pe1[:, 0::2] = torch.sin(position1 * div_term1)
pe1
tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 8.4147e-01, 0.0000e+00, 8.2186e-01, ..., 0.0000e+00,
1.0366e-04, 0.0000e+00],
[ 9.0930e-01, 0.0000e+00, 9.3641e-01, ..., 0.0000e+00,
2.0733e-04, 0.0000e+00],
...,
[ 9.5625e-01, 0.0000e+00, 9.3594e-01, ..., 0.0000e+00,
4.9515e-01, 0.0000e+00],
[ 2.7050e-01, 0.0000e+00, 8.2251e-01, ..., 0.0000e+00,
4.9524e-01, 0.0000e+00],
[-6.6395e-01, 0.0000e+00, 1.4615e-03, ..., 0.0000e+00,
4.9533e-01, 0.0000e+00]])
pe1[:, 1::2] = torch.cos(position1 * div_term1)
pe1
tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,
0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 8.2186e-01, ..., 1.0000e+00,
1.0366e-04, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 9.3641e-01, ..., 1.0000e+00,
2.0733e-04, 1.0000e+00],
...,
[ 9.5625e-01, -2.9254e-01, 9.3594e-01, ..., 8.5926e-01,
4.9515e-01, 8.6881e-01],
[ 2.7050e-01, -9.6272e-01, 8.2251e-01, ..., 8.5920e-01,
4.9524e-01, 8.6876e-01],
[-6.6395e-01, -7.4778e-01, 1.4615e-03, ..., 8.5915e-01,
4.9533e-01, 8.6871e-01]])
pe1.shape
torch.Size([5000, 512])
pe1 = pe1.reshape(1,5000,512)
pe1.shape
torch.Size([1, 5000, 512])
pe1.reshape(1, 5000, 512)
plot_position_embedding(pe1)
get_attn_pad_mask
inputs = torch.tensor([[7, 6, 0, 0, 0], [1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
get_attn_pad_mask(inputs, inputs)
tensor([[[False, False, True, True, True], [False, False, True, True, True], [False, False, True, True, True], [False, False, True, True, True], [False, False, True, True, True]], [[False, False, False, True, True], [False, False, False, True, True], [False, False, False, True, True], [False, False, False, True, True], [False, False, False, True, True]], [[False, False, True, True, True], [False, False, True, True, True], [False, False, True, True, True], [False, False, True, True, True], [False, False, True, True, True]]])
inputs.shape
torch.Size([3, 5])
batch, len1 = inputs.size()
print(batch)
print(len1)
# 7. Scaled dot-product attention
# Q is the query, K and V are the keys and values; after Q and K are multiplied, the mask is applied
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]
        # fill the masked positions with a very large negative number so they are ~0 after the softmax
        scores.masked_fill_(attn_mask, -1e9)  # fills elements of the tensor with the value where mask is True
        # dim=-1 applies the softmax over the last dimension, i.e. within each row. For example:
        # [[1.0, 2.0, 3.0, 4.0],         [[0.0321, 0.0871, 0.2369, 0.6439],
        #  [5.0, 6.0, 7.0, 8.0],    -->   [0.0321, 0.0871, 0.2369, 0.6439],
        #  [9.0, 10.0, 11.0, 12.0]]       [0.0321, 0.0871, 0.2369, 0.6439]]
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn


# 6. Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # the incoming Q, K, V tensors are identical; linear layers project them with the weight matrices Wq, Wk, Wv
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        # split into heads: first project, then compute the attention scores, then the attention values
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        # project first, then split into heads; Q and K must share the same per-head dimension d_k
        # because their inner product is taken
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
        # self.W_Q(input_Q) applies the W_Q linear projection to the queries;
        # .view(batch_size, -1, n_heads, d_k) reshapes the result (-1 lets PyTorch infer the sequence length)
        # into the layout needed for multi-head attention
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        # unsqueeze(1) adds a head dimension and repeat(1, n_heads, 1, 1) copies the mask n_heads times,
        # giving [batch_size, n_heads, seq_len, seq_len] so that every head uses the same mask
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask: [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        # then run the scaled dot-product attention (step 7)
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        # note: a fresh nn.LayerNorm is instantiated on every forward call, so its affine parameters are never trained;
        # a more conventional version would create it once in __init__
        return nn.LayerNorm(d_model).cuda()(output + residual), attn


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),  # d_model = 512, d_ff = 2048
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual)  # [batch_size, seq_len, d_model]


# 5. EncoderLayer: two sub-layers, multi-head self-attention and the position-wise feed-forward network
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()  # multi-head self-attention layer
        self.pos_ffn = PoswiseFeedForwardNet()     # fully connected feed-forward network

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        # see the multi-head attention implementation (step 6)
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # enc_inputs as Q, K and V
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn


class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn


# 2. Encoder: three parts - word embedding, positional encoding, and the stacked attention + feed-forward layers
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        # an embedding matrix of size src_vocab_size x d_model; src_vocab_size is the number of source words (6 here)
        self.pos_emb = PositionalEncoding(d_model)
        # positional encoding with fixed sine/cosine functions; a learnable nn.Embedding could be used instead
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        # stack several encoder layers with ModuleList; the later layers do not use the word embedding or
        # positional encoding, which is why those are kept outside the layer

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        # src_emb looks up each token id, so enc_outputs has shape [batch_size, src_len, d_model]:
        # every word in the input sentence becomes a word vector
        enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # add the positional encoding (see step 3); PositionalEncoding expects [seq_len, batch_size, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        # get_attn_pad_mask records where the pad tokens are, so that self-attention and cross-attention
        # can ignore them later (see step 4)
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # see EncoderLayer (step 5)
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns


# 9. Decoder
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda()  # [batch_size, tgt_len, d_model]
        # get_attn_pad_mask: the pad positions for the decoder self-attention
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]
        # get_attn_subsequence_mask: hide the tokens after the current position with an upper-triangular matrix of ones
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]
        # add the two masks: positions greater than 0 become 1 (later filled with a large negative value), the rest 0
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda()  # [batch_size, tgt_len, tgt_len]

        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model],
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len],
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns


# 1. The overall network has three parts: the encoder, the decoder and the output projection
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()  # encoder
        self.decoder = Decoder().cuda()  # decoder
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()
        # output layer: d_model is the dimension of each decoder token's output; the projection maps the
        # 512-dimensional output to tgt_vocab_size (9), over which a softmax predicts which word comes next

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # two inputs: enc_inputs of shape [batch_size, src_len] for the encoder,
        # and dec_inputs of shape [batch_size, tgt_len] for the decoder
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)

        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        # enc_outputs is the main output; enc_self_attns are the softmaxed QK^T matrices, i.e. how related
        # each word is to the other words, mainly kept for visualization
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model],
        # dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len],
        # dec_enc_attns: [n_layers, batch_size, n_heads, tgt_len, src_len]
        # the decoder takes two inputs: the decoder input (the "outputs" in the model diagram) and the
        # encoder output (used by the cross-attention); dec_outputs feeds the final linear projection,
        # dec_self_attns records the attention within the target sequence, and dec_enc_attns records the
        # attention of each target word over the source words
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        # project dec_outputs to the vocabulary size
        dec_logits = self.projection(dec_outputs)  # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns


model = Transformer().cuda()
# Two rules when writing a model:
# 1. Build from the whole to the parts: set up the overall skeleton first, then fill in the details.
# 2. Always keep track of tensor shapes: know the input and output shape of every module
#    (knowing the output shape tells you the input shape of the next part and how to write the code).
criterion = nn.CrossEntropyLoss(ignore_index=0)
# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
Verifying the multi-head attention module
i = torch.rand(1, 60)
print(i.shape)
i = torch.rand(1, 60).to('cuda')
print(i.shape)
a = get_attn_pad_mask(i, i).to('cuda')
temp_mha = MultiHeadAttention().to('cuda')
# create a dummy input
y = torch.rand(1, 60, 512).to('cuda')
# run the attention, using y as q, k and v
output, attn = temp_mha.forward(y, y, y, attn_mask = a)
print(output.shape)
print(attn.shape)
torch.Size([1, 60])
torch.Size([1, 60, 512])
torch.Size([1, 8, 60, 60])
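In the same spirit, a quick end-to-end shape check of the whole model can be run before training (a sketch: it assumes a CUDA device, since the classes above hard-code .cuda(), and reuses enc_inputs/dec_inputs from make_data):
with torch.no_grad():
    logits, _, _, _ = model(enc_inputs.cuda(), dec_inputs.cuda())
print(logits.shape)  # torch.Size([12, 9]) = [batch_size * tgt_len, tgt_vocab_size]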
A custom learning-rate schedule for the optimizer
This works better than a fixed learning rate.
class CustomizedSchedule1:
    def __init__(self, d_model, warmup_steps=4000):
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = (step + 1) ** -0.5              # avoid a negative power of 0
        arg2 = step * (self.warmup_steps ** (-1.5))
        arg3 = (self.d_model ** -0.5)
        return arg3 * min(arg1, arg2)

learning_rate_fn = CustomizedSchedule1(d_model)

# Create the optimizer and the learning-rate scheduler.
# LambdaLR multiplies the optimizer's base lr by the value returned by learning_rate_fn,
# so the base lr must be 1.0 (with a base lr of 0 the effective learning rate would always stay 0).
optimizer1 = torch.optim.Adam(model.parameters(), lr=1.0)
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer1, lr_lambda=learning_rate_fn)
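A small illustration of how LambdaLR applies the schedule: the effective learning rate is base_lr * learning_rate_fn(step), which is why the base lr above must be 1.0. A throwaway optimizer (the demo_* names are only for this check) keeps optimizer1/scheduler1 untouched:
demo_param = torch.zeros(1, requires_grad=True)
demo_opt = torch.optim.Adam([demo_param], lr=1.0)
demo_sched = torch.optim.lr_scheduler.LambdaLR(demo_opt, lr_lambda=CustomizedSchedule1(d_model))
for step in range(4):
    # the lr reported by the optimizer matches the schedule value for this step
    print(step, demo_opt.param_groups[0]['lr'], CustomizedSchedule1(d_model)(step))
    demo_opt.step()
    demo_sched.step()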
# The plot shows the effect of warmup_steps: the learning rate rises until warmup_steps is reached and then decays slowly (one training batch = one step)
temp_learning_rate_schedule = CustomizedSchedule1(d_model)
# Below is the learning-rate schedule
# plot the learning rate as a function of the training step
steps = torch.arange(40000, dtype=torch.float32)
learning_rates = [temp_learning_rate_schedule(step).item() for step in steps]
plt.plot(steps.numpy(), learning_rates)
plt.ylabel("Learning rate")
plt.xlabel("Train step")
plt.show()
Training
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        dec_outputs: [batch_size, tgt_len]
        '''
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        # view(-1) flattens dec_outputs into a one-dimensional tensor of target indices
        loss = criterion(outputs, dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()
        # update the learning rate with the scheduler
        scheduler1.step()
Epoch: 0001 loss = 2.700844 Epoch: 0002 loss = 2.676570 Epoch: 0003 loss = 2.627945 Epoch: 0004 loss = 2.606202 Epoch: 0005 loss = 2.650645 Epoch: 0006 loss = 2.588106 Epoch: 0007 loss = 2.445351 Epoch: 0008 loss = 2.508237 Epoch: 0009 loss = 2.447822 Epoch: 0010 loss = 2.325812 Epoch: 0011 loss = 2.384464 Epoch: 0012 loss = 2.271541 Epoch: 0013 loss = 2.213017 Epoch: 0014 loss = 2.177427 Epoch: 0015 loss = 2.198853 Epoch: 0016 loss = 2.030714 Epoch: 0017 loss = 1.964916 Epoch: 0018 loss = 1.903209 Epoch: 0019 loss = 1.902373 Epoch: 0020 loss = 1.778085 Epoch: 0021 loss = 1.725913 Epoch: 0022 loss = 1.744313 Epoch: 0023 loss = 1.646813 Epoch: 0024 loss = 1.606654 Epoch: 0025 loss = 1.525460 Epoch: 0026 loss = 1.460560 Epoch: 0027 loss = 1.428008 Epoch: 0028 loss = 1.362813 Epoch: 0029 loss = 1.315524 Epoch: 0030 loss = 1.295405 Epoch: 0031 loss = 1.200210 Epoch: 0032 loss = 1.158323 Epoch: 0033 loss = 1.108451 Epoch: 0034 loss = 0.951094 Epoch: 0035 loss = 0.981314 Epoch: 0036 loss = 0.951581 Epoch: 0037 loss = 0.815680 Epoch: 0038 loss = 0.751278 Epoch: 0039 loss = 0.715567 Epoch: 0040 loss = 0.625560 Epoch: 0041 loss = 0.587802 Epoch: 0042 loss = 0.575087 Epoch: 0043 loss = 0.506960 Epoch: 0044 loss = 0.455274 Epoch: 0045 loss = 0.419664 Epoch: 0046 loss = 0.354564 Epoch: 0047 loss = 0.351857 Epoch: 0048 loss = 0.297130 Epoch: 0049 loss = 0.278830 Epoch: 0050 loss = 0.283241 Epoch: 0051 loss = 0.212995 Epoch: 0052 loss = 0.218633 Epoch: 0053 loss = 0.181479 Epoch: 0054 loss = 0.162638 Epoch: 0055 loss = 0.146464 Epoch: 0056 loss = 0.134797 Epoch: 0057 loss = 0.101864 Epoch: 0058 loss = 0.116349 Epoch: 0059 loss = 0.098998 Epoch: 0060 loss = 0.104876 Epoch: 0061 loss = 0.083773 Epoch: 0062 loss = 0.074446 Epoch: 0063 loss = 0.078839 Epoch: 0064 loss = 0.073547 Epoch: 0065 loss = 0.064128 Epoch: 0066 loss = 0.060836 Epoch: 0067 loss = 0.053429 Epoch: 0068 loss = 0.049494 Epoch: 0069 loss = 0.056137 Epoch: 0070 loss = 0.049774 Epoch: 0071 loss = 0.042814 Epoch: 0072 loss = 0.041327 Epoch: 0073 loss = 0.044661 Epoch: 0074 loss = 0.036507 Epoch: 0075 loss = 0.035213 Epoch: 0076 loss = 0.033580 Epoch: 0077 loss = 0.028917 Epoch: 0078 loss = 0.030519 Epoch: 0079 loss = 0.031743 Epoch: 0080 loss = 0.024252 Epoch: 0081 loss = 0.028838 Epoch: 0082 loss = 0.026516 Epoch: 0083 loss = 0.022355 Epoch: 0084 loss = 0.023337 Epoch: 0085 loss = 0.022203 Epoch: 0086 loss = 0.022297 Epoch: 0087 loss = 0.020514 Epoch: 0088 loss = 0.021469 Epoch: 0089 loss = 0.019400 Epoch: 0090 loss = 0.017382 Epoch: 0091 loss = 0.017562 Epoch: 0092 loss = 0.017067 Epoch: 0093 loss = 0.017831 Epoch: 0094 loss = 0.016011 Epoch: 0095 loss = 0.016058 Epoch: 0096 loss = 0.016809 Epoch: 0097 loss = 0.016994 Epoch: 0098 loss = 0.015828 Epoch: 0099 loss = 0.015224 Epoch: 0100 loss = 0.014666 Epoch: 0101 loss = 0.013139 Epoch: 0102 loss = 0.014083 Epoch: 0103 loss = 0.014809 Epoch: 0104 loss = 0.013538 Epoch: 0105 loss = 0.012494 Epoch: 0106 loss = 0.014155 Epoch: 0107 loss = 0.012296 Epoch: 0108 loss = 0.012803 Epoch: 0109 loss = 0.012513 Epoch: 0110 loss = 0.011238 Epoch: 0111 loss = 0.010483 Epoch: 0112 loss = 0.011260 Epoch: 0113 loss = 0.010968 Epoch: 0114 loss = 0.010077 Epoch: 0115 loss = 0.011757 Epoch: 0116 loss = 0.010623 Epoch: 0117 loss = 0.009586 Epoch: 0118 loss = 0.011541 Epoch: 0119 loss = 0.009942 Epoch: 0120 loss = 0.009321 Epoch: 0121 loss = 0.009212 Epoch: 0122 loss = 0.010276 Epoch: 0123 loss = 0.008764 Epoch: 0124 loss = 0.009939 Epoch: 0125 loss = 0.009529 Epoch: 0126 loss = 0.009097 Epoch: 0127 loss = 
0.009702 Epoch: 0128 loss = 0.008433 Epoch: 0129 loss = 0.007843 Epoch: 0130 loss = 0.007821 Epoch: 0131 loss = 0.007701 Epoch: 0132 loss = 0.008602 Epoch: 0133 loss = 0.007557 Epoch: 0134 loss = 0.007202 Epoch: 0135 loss = 0.007281 Epoch: 0136 loss = 0.008016 Epoch: 0137 loss = 0.007663 Epoch: 0138 loss = 0.006558 Epoch: 0139 loss = 0.006710 Epoch: 0140 loss = 0.007544 Epoch: 0141 loss = 0.007100 Epoch: 0142 loss = 0.006660 Epoch: 0143 loss = 0.006323 Epoch: 0144 loss = 0.007174 Epoch: 0145 loss = 0.007030 Epoch: 0146 loss = 0.006387 Epoch: 0147 loss = 0.006057 Epoch: 0148 loss = 0.006924 Epoch: 0149 loss = 0.006338 Epoch: 0150 loss = 0.006006 Epoch: 0151 loss = 0.006640 Epoch: 0152 loss = 0.005767 Epoch: 0153 loss = 0.006193 Epoch: 0154 loss = 0.006098 Epoch: 0155 loss = 0.005454 Epoch: 0156 loss = 0.005646 Epoch: 0157 loss = 0.005594 Epoch: 0158 loss = 0.005090 Epoch: 0159 loss = 0.005108 Epoch: 0160 loss = 0.005298 Epoch: 0161 loss = 0.005076 Epoch: 0162 loss = 0.004849 Epoch: 0163 loss = 0.004957 Epoch: 0164 loss = 0.004956 Epoch: 0165 loss = 0.004824 Epoch: 0166 loss = 0.005442 Epoch: 0167 loss = 0.004251 Epoch: 0168 loss = 0.004888 ...... Epoch: 0302 loss = 0.001038 Epoch: 0303 loss = 0.001023 Epoch: 0304 loss = 0.000993 Epoch: 0305 loss = 0.000966 Epoch: 0306 loss = 0.001011 Epoch: 0307 loss = 0.001065 Epoch: 0308 loss = 0.001040 Epoch: 0309 loss = 0.000844 Epoch: 0310 loss = 0.000965 Epoch: 0311 loss = 0.001100 Epoch: 0312 loss = 0.001046 Epoch: 0313 loss = 0.000882 Epoch: 0314 loss = 0.000906 Epoch: 0315 loss = 0.000894 Epoch: 0316 loss = 0.000892 Epoch: 0317 loss = 0.000900 Epoch: 0318 loss = 0.001025 Epoch: 0319 loss = 0.000956 Epoch: 0320 loss = 0.001024 Epoch: 0321 loss = 0.000897 Epoch: 0322 loss = 0.000921 Epoch: 0323 loss = 0.000793 Epoch: 0324 loss = 0.000866 Epoch: 0325 loss = 0.000796 Epoch: 0326 loss = 0.000834 Epoch: 0327 loss = 0.000856 Epoch: 0328 loss = 0.000763 Epoch: 0329 loss = 0.000883 Epoch: 0330 loss = 0.000763 Epoch: 0331 loss = 0.000690 Epoch: 0332 loss = 0.000756 Epoch: 0333 loss = 0.000788 Epoch: 0334 loss = 0.000847 Epoch: 0335 loss = 0.000725 Epoch: 0336 loss = 0.000763 Epoch: 0337 loss = 0.000725 Epoch: 0338 loss = 0.000656 Epoch: 0339 loss = 0.000789 Epoch: 0340 loss = 0.000730 Epoch: 0341 loss = 0.000705 Epoch: 0342 loss = 0.000837 Epoch: 0343 loss = 0.000703 Epoch: 0344 loss = 0.000695 Epoch: 0345 loss = 0.000679 Epoch: 0346 loss = 0.000641 Epoch: 0347 loss = 0.000685 Epoch: 0348 loss = 0.000693 Epoch: 0349 loss = 0.000647 Epoch: 0350 loss = 0.000615 Epoch: 0351 loss = 0.000744 Epoch: 0352 loss = 0.000730 Epoch: 0353 loss = 0.000637 Epoch: 0354 loss = 0.000790 Epoch: 0355 loss = 0.000594 Epoch: 0356 loss = 0.000795 Epoch: 0357 loss = 0.000631 Epoch: 0358 loss = 0.000591 Epoch: 0359 loss = 0.000648 Epoch: 0360 loss = 0.000670 Epoch: 0361 loss = 0.000523 Epoch: 0362 loss = 0.000529 Epoch: 0363 loss = 0.000568 Epoch: 0364 loss = 0.000566 Epoch: 0365 loss = 0.000552 Epoch: 0366 loss = 0.000576 ...... 
Epoch: 0554 loss = 0.000147 Epoch: 0555 loss = 0.000168 Epoch: 0556 loss = 0.000158 Epoch: 0557 loss = 0.000180 Epoch: 0558 loss = 0.000146 Epoch: 0559 loss = 0.000140 Epoch: 0560 loss = 0.000141 Epoch: 0561 loss = 0.000144 Epoch: 0562 loss = 0.000151 Epoch: 0563 loss = 0.000136 Epoch: 0564 loss = 0.000153 Epoch: 0565 loss = 0.000130 Epoch: 0566 loss = 0.000137 Epoch: 0567 loss = 0.000128 Epoch: 0568 loss = 0.000133 Epoch: 0569 loss = 0.000125 Epoch: 0570 loss = 0.000131 Epoch: 0571 loss = 0.000143 Epoch: 0572 loss = 0.000132 Epoch: 0573 loss = 0.000128 Epoch: 0574 loss = 0.000135 Epoch: 0575 loss = 0.000132 Epoch: 0576 loss = 0.000123 Epoch: 0577 loss = 0.000128 Epoch: 0578 loss = 0.000117 Epoch: 0579 loss = 0.000126 Epoch: 0580 loss = 0.000153 Epoch: 0581 loss = 0.000123 Epoch: 0582 loss = 0.000133 Epoch: 0583 loss = 0.000122 Epoch: 0584 loss = 0.000132 Epoch: 0585 loss = 0.000117 Epoch: 0586 loss = 0.000129 Epoch: 0587 loss = 0.000124 Epoch: 0588 loss = 0.000119 Epoch: 0589 loss = 0.000127 Epoch: 0590 loss = 0.000123 Epoch: 0591 loss = 0.000102 Epoch: 0592 loss = 0.000128 Epoch: 0593 loss = 0.000130 Epoch: 0594 loss = 0.000140 Epoch: 0595 loss = 0.000116 Epoch: 0596 loss = 0.000104 Epoch: 0597 loss = 0.000110 Epoch: 0598 loss = 0.000128 Epoch: 0599 loss = 0.000129 Epoch: 0600 loss = 0.000113 Epoch: 0601 loss = 0.000107 Epoch: 0602 loss = 0.000112 Epoch: 0603 loss = 0.000111 Epoch: 0604 loss = 0.000113 Epoch: 0605 loss = 0.000116 Epoch: 0606 loss = 0.000121 Epoch: 0607 loss = 0.000119 Epoch: 0608 loss = 0.000119 Epoch: 0609 loss = 0.000123 Epoch: 0610 loss = 0.000108 Epoch: 0611 loss = 0.000125 Epoch: 0612 loss = 0.000108 Epoch: 0613 loss = 0.000118 Epoch: 0614 loss = 0.000108 Epoch: 0615 loss = 0.000119 Epoch: 0616 loss = 0.000110 Epoch: 0617 loss = 0.000111 Epoch: 0618 loss = 0.000105 Epoch: 0619 loss = 0.000103 Epoch: 0620 loss = 0.000097 Epoch: 0621 loss = 0.000112 Epoch: 0622 loss = 0.000092 Epoch: 0623 loss = 0.000105 Epoch: 0624 loss = 0.000108 Epoch: 0625 loss = 0.000101 Epoch: 0626 loss = 0.000089 Epoch: 0627 loss = 0.000105 Epoch: 0628 loss = 0.000097 Epoch: 0629 loss = 0.000103 Epoch: 0630 loss = 0.000109 Epoch: 0631 loss = 0.000102 Epoch: 0632 loss = 0.000087 ...... 
Epoch: 0764 loss = 0.000061 Epoch: 0765 loss = 0.000063 Epoch: 0766 loss = 0.000061 Epoch: 0767 loss = 0.000060 Epoch: 0768 loss = 0.000061 Epoch: 0769 loss = 0.000064 Epoch: 0770 loss = 0.000058 Epoch: 0771 loss = 0.000061 Epoch: 0772 loss = 0.000064 Epoch: 0773 loss = 0.000064 Epoch: 0774 loss = 0.000063 Epoch: 0775 loss = 0.000058 Epoch: 0776 loss = 0.000057 Epoch: 0777 loss = 0.000060 Epoch: 0778 loss = 0.000058 Epoch: 0779 loss = 0.000061 Epoch: 0780 loss = 0.000061 Epoch: 0781 loss = 0.000059 Epoch: 0782 loss = 0.000058 Epoch: 0783 loss = 0.000060 Epoch: 0784 loss = 0.000055 Epoch: 0785 loss = 0.000063 Epoch: 0786 loss = 0.000056 Epoch: 0787 loss = 0.000056 Epoch: 0788 loss = 0.000058 Epoch: 0789 loss = 0.000060 Epoch: 0790 loss = 0.000057 Epoch: 0791 loss = 0.000055 Epoch: 0792 loss = 0.000050 Epoch: 0793 loss = 0.000050 Epoch: 0794 loss = 0.000051 Epoch: 0795 loss = 0.000058 Epoch: 0796 loss = 0.000052 Epoch: 0797 loss = 0.000057 Epoch: 0798 loss = 0.000054 Epoch: 0799 loss = 0.000051 Epoch: 0800 loss = 0.000057 Epoch: 0801 loss = 0.000055 Epoch: 0802 loss = 0.000052 Epoch: 0803 loss = 0.000054 Epoch: 0804 loss = 0.000052 Epoch: 0805 loss = 0.000056 Epoch: 0806 loss = 0.000053 Epoch: 0807 loss = 0.000055 Epoch: 0808 loss = 0.000056 Epoch: 0809 loss = 0.000058 Epoch: 0810 loss = 0.000054 Epoch: 0811 loss = 0.000055 Epoch: 0812 loss = 0.000049 Epoch: 0813 loss = 0.000057 Epoch: 0814 loss = 0.000053 Epoch: 0815 loss = 0.000053 Epoch: 0816 loss = 0.000052 Epoch: 0817 loss = 0.000047 Epoch: 0818 loss = 0.000051 Epoch: 0819 loss = 0.000051 Epoch: 0820 loss = 0.000052 Epoch: 0821 loss = 0.000053 Epoch: 0822 loss = 0.000054 Epoch: 0823 loss = 0.000057 Epoch: 0824 loss = 0.000050 Epoch: 0825 loss = 0.000047 Epoch: 0826 loss = 0.000051 Epoch: 0827 loss = 0.000048 Epoch: 0828 loss = 0.000048 Epoch: 0829 loss = 0.000050 Epoch: 0830 loss = 0.000050 Epoch: 0831 loss = 0.000052 Epoch: 0832 loss = 0.000049 Epoch: 0833 loss = 0.000049 Epoch: 0834 loss = 0.000052 Epoch: 0835 loss = 0.000050 Epoch: 0836 loss = 0.000049 Epoch: 0837 loss = 0.000046 Epoch: 0838 loss = 0.000047 Epoch: 0839 loss = 0.000047 Epoch: 0840 loss = 0.000054 Epoch: 0841 loss = 0.000048 Epoch: 0842 loss = 0.000050 Epoch: 0843 loss = 0.000051 Epoch: 0844 loss = 0.000046 Epoch: 0845 loss = 0.000046 Epoch: 0846 loss = 0.000047 Epoch: 0847 loss = 0.000050 Epoch: 0848 loss = 0.000051 Epoch: 0849 loss = 0.000049 Epoch: 0850 loss = 0.000048 Epoch: 0851 loss = 0.000045 Epoch: 0852 loss = 0.000051 Epoch: 0853 loss = 0.000050 Epoch: 0854 loss = 0.000045 Epoch: 0855 loss = 0.000049 Epoch: 0856 loss = 0.000045 Epoch: 0857 loss = 0.000048 Epoch: 0858 loss = 0.000046 Epoch: 0859 loss = 0.000044 Epoch: 0860 loss = 0.000044 Epoch: 0861 loss = 0.000048 Epoch: 0862 loss = 0.000045 Epoch: 0863 loss = 0.000047 Epoch: 0864 loss = 0.000046 Epoch: 0865 loss = 0.000046 Epoch: 0866 loss = 0.000048 Epoch: 0867 loss = 0.000045 Epoch: 0868 loss = 0.000049 Epoch: 0869 loss = 0.000044 Epoch: 0870 loss = 0.000045 Epoch: 0871 loss = 0.000047 Epoch: 0872 loss = 0.000047 Epoch: 0873 loss = 0.000046 Epoch: 0874 loss = 0.000045 Epoch: 0875 loss = 0.000046 Epoch: 0876 loss = 0.000045 Epoch: 0877 loss = 0.000047 Epoch: 0878 loss = 0.000044 Epoch: 0879 loss = 0.000047 Epoch: 0880 loss = 0.000046 Epoch: 0881 loss = 0.000045 Epoch: 0882 loss = 0.000042 Epoch: 0883 loss = 0.000044 Epoch: 0884 loss = 0.000047 Epoch: 0885 loss = 0.000041 Epoch: 0886 loss = 0.000045 Epoch: 0887 loss = 0.000044 Epoch: 0888 loss = 0.000042 Epoch: 0889 loss = 0.000039 ...... 
Epoch: 0938 loss = 0.000034 Epoch: 0939 loss = 0.000037 Epoch: 0940 loss = 0.000039 Epoch: 0941 loss = 0.000042 Epoch: 0942 loss = 0.000037 Epoch: 0943 loss = 0.000036 Epoch: 0944 loss = 0.000039 Epoch: 0945 loss = 0.000036 Epoch: 0946 loss = 0.000039 Epoch: 0947 loss = 0.000037 Epoch: 0948 loss = 0.000037 Epoch: 0949 loss = 0.000038 Epoch: 0950 loss = 0.000037 Epoch: 0951 loss = 0.000041 Epoch: 0952 loss = 0.000036 Epoch: 0953 loss = 0.000037 Epoch: 0954 loss = 0.000039 Epoch: 0955 loss = 0.000037 Epoch: 0956 loss = 0.000038 Epoch: 0957 loss = 0.000036 Epoch: 0958 loss = 0.000039 Epoch: 0959 loss = 0.000035 Epoch: 0960 loss = 0.000038 Epoch: 0961 loss = 0.000039 Epoch: 0962 loss = 0.000038 Epoch: 0963 loss = 0.000038 Epoch: 0964 loss = 0.000036 Epoch: 0965 loss = 0.000035 Epoch: 0966 loss = 0.000034 Epoch: 0967 loss = 0.000037 Epoch: 0968 loss = 0.000036 Epoch: 0969 loss = 0.000035 Epoch: 0970 loss = 0.000035 Epoch: 0971 loss = 0.000038 Epoch: 0972 loss = 0.000036 Epoch: 0973 loss = 0.000036 Epoch: 0974 loss = 0.000037 Epoch: 0975 loss = 0.000034 Epoch: 0976 loss = 0.000036 Epoch: 0977 loss = 0.000033 Epoch: 0978 loss = 0.000037 Epoch: 0979 loss = 0.000035 Epoch: 0980 loss = 0.000035 Epoch: 0981 loss = 0.000034 Epoch: 0982 loss = 0.000035 Epoch: 0983 loss = 0.000034 Epoch: 0984 loss = 0.000033 Epoch: 0985 loss = 0.000034 Epoch: 0986 loss = 0.000037 Epoch: 0987 loss = 0.000033 Epoch: 0988 loss = 0.000035 Epoch: 0989 loss = 0.000034 Epoch: 0990 loss = 0.000035 Epoch: 0991 loss = 0.000035 Epoch: 0992 loss = 0.000034 Epoch: 0993 loss = 0.000032 Epoch: 0994 loss = 0.000037 Epoch: 0995 loss = 0.000035 Epoch: 0996 loss = 0.000037 Epoch: 0997 loss = 0.000034 Epoch: 0998 loss = 0.000034 Epoch: 0999 loss = 0.000037 Epoch: 1000 loss = 0.000036
def greedy_decoder(model, enc_input, start_symbol):
    """
    :param model: Transformer model
    :param enc_input: the encoder input
    :param start_symbol: the start symbol; in this example it is 'S', which corresponds to index 6
    :return: the target input

    This is a greedy decoder: it performs inference without a target sequence by generating the target
    tokens one at a time. It first encodes the encoder input and initializes an empty decoder input
    dec_input. In each loop iteration the model receives the current dec_input together with the encoder
    output and produces dec_outputs; after the projection layer model.projection, the word with the
    highest score is taken as the next target token. If that token is the period ".", decoding stops,
    otherwise the loop continues.

    A few implementation details: torch.cat concatenates tensors, detach separates data from the
    computation graph, and squeeze removes a singleton dimension; these keep the tensor shapes and
    dtypes consistent with what the model expects. The function finally returns the generated target
    input dec_input.
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    # create a (1, 0) tensor with the same dtype as enc_input
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        # torch.cat concatenates along the last dimension;
        # dec_input.detach() creates a new tensor detached from the computation graph;
        # torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda() wraps next_symbol as a tensor on the GPU
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        # squeeze(0) removes the batch dimension;
        # .max(dim=-1, keepdim=False)[1] takes the argmax over the last dimension (the vocabulary),
        # i.e. the index of the most probable word at every position
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
        print(next_word)
    return dec_input
# Test
enc_inputs, _, _ = next(iter(loader))
enc_inputs = enc_inputs.cuda()
for i in range(len(enc_inputs)):
greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(5, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 5, 0], device='cuda:0') -> ['i', 'want', 'a', 'coke', '.']
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(4, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 4, 0], device='cuda:0') -> ['i', 'want', 'a', 'beer', '.']
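To print the source side as words rather than raw ids, a reverse source vocabulary can be built the same way as idx2word (src_idx2word is an illustrative helper, not defined in the original code):
src_idx2word = {i: w for w, i in src_vocab.items()}
for seq in enc_inputs:
    print([src_idx2word[n.item()] for n in seq])  # e.g. ['ich', 'mochte', 'ein', 'cola', 'P']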