
Attention Is All You Need paper notes: Transformer code with detailed annotations (PyTorch)


Transformer overview

1. Overall model architecture

(figure: overall Transformer architecture)

2. Encoder

(figures: encoder structure)

3. Decoder

(figures: decoder structure)

4. Attention

(figures: scaled dot-product attention and multi-head attention)
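For reference, the attention shown in these figures is scaled dot-product attention, and the multi-head variant runs it several times in parallel on linearly projected queries, keys and values:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})
$$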

5. Cross attention

(figure: cross attention between encoder and decoder)

6. The three attention blocks in the paper

(figure: the three attention blocks in the Transformer)

Transformer code structure

(figure: Transformer code structure; image taken from Zhihu: https://zhuanlan.zhihu.com/p/107889011)
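As a rough map of the code that follows (my own summary, mirroring the structure in the figure above), the classes defined in this post nest as:

# Transformer
# ├── Encoder: nn.Embedding (src_emb) + PositionalEncoding + n_layers × EncoderLayer
# │       EncoderLayer = MultiHeadAttention (self-attention) + PoswiseFeedForwardNet
# ├── Decoder: nn.Embedding (tgt_emb) + PositionalEncoding + n_layers × DecoderLayer
# │       DecoderLayer = MultiHeadAttention (masked self-attention) + MultiHeadAttention (cross attention) + PoswiseFeedForwardNet
# └── nn.Linear projection to tgt_vocab_size
# MultiHeadAttention internally calls ScaledDotProductAttention; the masks come from get_attn_pad_mask / get_attn_subsequence_mask.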

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import _LRScheduler


# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output (the ground-truth label)
# P: Symbol that pads a sequence when it is shorter than the maximum sequence length
sentences = [
    # enc_input           dec_input         dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    # German. 'S' starts the decoder input; 'E' is not a decoder output but the end marker of the ground-truth label (dec_output), which is compared with the final output to compute the loss.
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']  # also German (the 'cola' example)
]

# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
src_vocab_size = len(src_vocab) # 6

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab) #9

src_len = 5 # enc_input max sequence length
tgt_len = 6 # dec_input(=dec_output) max sequence length

# Transformer Parameters
d_model = 512  # Embedding Size
d_ff = 2048 # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6  # number of encoder / decoder layers
n_heads = 8  # number of heads in Multi-Head Attention
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
      enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
      dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
      dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

      # extend appends the elements of one list to the end of another list
      enc_inputs.extend(enc_input)
      dec_inputs.extend(dec_input)
      dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
  def __init__(self, enc_inputs, dec_inputs, dec_outputs):
    super(MyDataSet, self).__init__()
    self.enc_inputs = enc_inputs
    self.dec_inputs = dec_inputs
    self.dec_outputs = dec_outputs

  def __len__(self):
    return self.enc_inputs.shape[0]

  def __getitem__(self, idx):
    return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True) # batchsize = 2, shuffle = True
enc, _, _ = make_data(sentences)  # only the encoder inputs are inspected here
print(enc)
print(enc[0])
print(enc.shape)
print(enc.shape[0])
tensor([[1, 2, 3, 4, 0],
        [1, 2, 3, 5, 0]])
tensor([1, 2, 3, 4, 0])
torch.Size([2, 5])
2
idx2word
{0: 'P',
 1: 'i',
 2: 'want',
 3: 'a',
 4: 'beer',
 5: 'coke',
 6: 'S',
 7: 'E',
 8: '.'}
emb = nn.Embedding(src_vocab_size, d_model)
enc_o = emb(enc) # [batch_size, src_len, d_model]
print(enc_o.shape)
print(enc_o)
torch.Size([2, 5, 512])
tensor([[[ 0.2115,  0.1797,  0.3934,  ..., -1.3825,  0.4188, -0.5924],
         [ 0.8305, -0.1925, -0.4087,  ..., -1.2020,  1.5736,  0.8208],
         [-2.3016,  0.3060, -0.2981,  ...,  1.8724, -0.3179, -0.6690],
         [ 0.6748,  0.2370, -1.0590,  ..., -0.2914, -0.0615,  0.2832],
         [-0.9452, -0.1783,  0.4750,  ...,  0.0894,  2.0903, -1.8880]],

        [[ 0.2115,  0.1797,  0.3934,  ..., -1.3825,  0.4188, -0.5924],
         [ 0.8305, -0.1925, -0.4087,  ..., -1.2020,  1.5736,  0.8208],
         [-2.3016,  0.3060, -0.2981,  ...,  1.8724, -0.3179, -0.6690],
         [ 0.5548, -0.0924,  0.0213,  ..., -1.6186,  1.0486, -0.0319],
         [-0.9452, -0.1783,  0.4750,  ...,  0.0894,  2.0903, -1.8880]]],
       grad_fn=<EmbeddingBackward0>)
def plot_position_embedding(position):
    plt.pcolormesh(position[0], cmap = 'RdBu')
    plt.xlabel('Depth')
    plt.xlim((0, 512))
    plt.colorbar()
    plt.show()
# enc_o = enc_o.detach().numpy()
# plot_position_embedding(enc_o)
# 3. Positional encoding
# PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # unsqueeze(1) adds a dimension, changing the shape from (max_len,) to (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # torch.arange(0, d_model, 2).float() builds a float tensor [0, 2, 4, ..., d_model - 2].
        # (-math.log(10000.0) / d_model) is a scalar; torch.exp() applies the exponential element-wise, giving 1 / 10000^(2i/d_model).
        pe[:, 0::2] = torch.sin(position * div_term) # even-indexed columns
        pe[:, 1::2] = torch.cos(position * div_term) # odd-indexed columns
        pe = pe.unsqueeze(0).transpose(0, 1) # [5000, 512] --> [1, 5000, 512] --> [5000, 1, 512]
        # pe.unsqueeze(0) adds a leading dimension: (max_len, d_model) -> (1, max_len, d_model).
        # transpose(0, 1) then swaps the first two dimensions, giving (max_len, 1, d_model) so that pe can be added to inputs of shape [seq_len, batch_size, d_model].
        self.register_buffer('pe', pe)
        # register_buffer adds a persistent tensor to the module: it is saved with the state dict and moved by .cuda()/.to(), but it is not a trainable parameter.

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        # x.size(0) is the length of the input sequence
        return self.dropout(x)

# 4. The attention score matrix also contains scores at the pad positions. To remove their influence we build a mask that marks the
# pad positions as 1 (True) and the real tokens as 0 (False); the marked positions are later set to a very large negative value
# before the softmax.
def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k may not be equal
    '''
    # seq_q and seq_k are not necessarily the same: in the cross attention, q comes from the decoder and k from the encoder,
    # so only the key-side (encoder) pad positions need to be masked here; the decoder's own pad positions are handled by the decoder self-attention mask.
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    # eq(0) compares every element of seq_k with zero and returns a tensor of the same size in which pad positions are True.
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
    # expand(batch_size, len_q, len_k) broadcasts the mask along the query dimension, so every query row shares the same key-side pad mask.
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

'''
Both seq2seq and the Transformer "predict" auto-regressively: the input at the current step is the output of the previous step.
Training differs, though: a seq2seq model is also fed one token at a time during training, so each step still depends on the previous output,
while the Transformer is trained in parallel and the whole target sequence is fed in at once. To keep the model from cheating,
a mask hides the future tokens: an upper triangular matrix whose entries strictly above the main diagonal are 1 and whose diagonal
and lower part are 0, so the marked positions can later be filled with a very large negative value.
'''
def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''

    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]

    # np.ones(attn_shape) creates an all-ones matrix of shape attn_shape, used as the basis of the triangular mask.
    # np.triu(..., k=1) keeps the part strictly above the main diagonal and zeros the diagonal and everything below it.
    subsequence_mask = np.triu(np.ones(attn_shape), k=1) # Upper triangular matrix

    # Convert the NumPy array to a PyTorch tensor of byte (uint8) type.
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask # [batch_size, tgt_len, tgt_len]

get_attn_subsequence_mask

# Earlier positions must not see later positions; the diagonal and everything below it are 0 (kept).
# In the mask, positions that should be ignored are set to 1 and positions that should be kept are set to 0.
# During the attention computation, the positions marked 1 are filled with a very large negative number, so after the softmax their attention weights are essentially 0.
x = torch.tensor([[7, 6, 0, 0, 0], [1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
get_attn_subsequence_mask(x)
tensor([[[0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]]], dtype=torch.uint8)
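To see the masking idea end to end, here is a minimal sketch of my own (not part of the original notebook; the variable names are illustrative) that applies the subsequence mask to a toy score matrix and checks that the masked positions get numerically zero attention weight after the softmax:

scores = torch.randn(1, 5, 5)                                       # toy attention scores [batch_size, len_q, len_k]
mask = get_attn_subsequence_mask(torch.zeros(1, 5).long()).bool()   # [1, 5, 5], True strictly above the diagonal
masked_scores = scores.masked_fill(mask, -1e9)                      # masked positions become a huge negative number
weights = nn.Softmax(dim=-1)(masked_scores)                         # ...so their softmax weights are ~0
print(weights[0])                                                   # the upper triangle prints as 0.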

PositionalEncoding

position_embedding = PositionalEncoding(d_model)
input_tensor = torch.zeros(1, 50, 512)
position = position_embedding.forward(input_tensor)
print(position)
print(position.shape)
tensor([[[0.0000, 1.1111, 0.0000,  ..., 1.1111, 0.0000, 1.1111],
         [0.0000, 1.1111, 0.0000,  ..., 1.1111, 0.0000, 1.1111],
         [0.0000, 0.0000, 0.0000,  ..., 1.1111, 0.0000, 1.1111],
         ...,
         [0.0000, 1.1111, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 1.1111, 0.0000,  ..., 1.1111, 0.0000, 1.1111],
         [0.0000, 1.1111, 0.0000,  ..., 1.1111, 0.0000, 1.1111]]])
torch.Size([1, 50, 512])
# Note: the x and y axes in this plot are swapped compared with the figure shown in the theory section.
def plot_position_embedding(position):
    plt.pcolormesh(position[0], cmap = 'RdBu')
    plt.xlabel('Depth')
    plt.xlim((0, 512))
    plt.ylabel('Position')
    plt.colorbar()
    plt.show()

plot_position_embedding(position)

(figure: positional encoding heatmap)

pe1 = torch.zeros(5000, 512)
position1 = torch.arange(0, 5000, dtype=torch.float).unsqueeze(1) # unsqueeze(1) changes the shape from (5000,) to (5000, 1)
div_term1 = torch.exp(torch.arange(0, 512, 2).float() * (-math.log(10000.0) / 512))
pe1[:, 0::2] = torch.sin(position1 * div_term1)
pe1
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 8.4147e-01,  0.0000e+00,  8.2186e-01,  ...,  0.0000e+00,
          1.0366e-04,  0.0000e+00],
        [ 9.0930e-01,  0.0000e+00,  9.3641e-01,  ...,  0.0000e+00,
          2.0733e-04,  0.0000e+00],
        ...,
        [ 9.5625e-01,  0.0000e+00,  9.3594e-01,  ...,  0.0000e+00,
          4.9515e-01,  0.0000e+00],
        [ 2.7050e-01,  0.0000e+00,  8.2251e-01,  ...,  0.0000e+00,
          4.9524e-01,  0.0000e+00],
        [-6.6395e-01,  0.0000e+00,  1.4615e-03,  ...,  0.0000e+00,
          4.9533e-01,  0.0000e+00]])
pe1[:, 1::2] = torch.cos(position1 * div_term1)
pe1
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
          2.0733e-04,  1.0000e+00],
        ...,
        [ 9.5625e-01, -2.9254e-01,  9.3594e-01,  ...,  8.5926e-01,
          4.9515e-01,  8.6881e-01],
        [ 2.7050e-01, -9.6272e-01,  8.2251e-01,  ...,  8.5920e-01,
          4.9524e-01,  8.6876e-01],
        [-6.6395e-01, -7.4778e-01,  1.4615e-03,  ...,  8.5915e-01,
          4.9533e-01,  8.6871e-01]])
pe1.shape
torch.Size([5000, 512])
pe1 = pe1.reshape(1,5000,512)
pe1.shape
torch.Size([1, 5000, 512])
plot_position_embedding(pe1)

(figure: positional encoding heatmap plotted from pe1)

get_attn_pad_mask

inputs = torch.tensor([[7, 6, 0, 0, 0], [1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
get_attn_pad_mask(inputs, inputs)
tensor([[[False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True]],

        [[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True]]])
inputs.shape
torch.Size([3, 5])
batch, len1 = inputs.size()
print(batch)
print(len1)
# 7. Scaled dot-product attention
# Q is the query; K and V are the keys and values. After the Q·K^T matmul, the mask is applied before the softmax.
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores : [batch_size, n_heads, len_q, len_k]
        # Set the masked positions to a very large negative value so that they become essentially 0 after the softmax.
        scores.masked_fill_(attn_mask, -1e9)  # Fills elements of self tensor with value where mask is True.

        # dim=-1 applies the softmax over the last dimension, i.e. row-wise.
        '''Example:
        [[1.0, 2.0, 3.0, 4.0],             [[0.0321, 0.0871, 0.2369, 0.6439],
         [5.0, 6.0, 7.0, 8.0],      -->     [0.0321, 0.0871, 0.2369, 0.6439],
         [9.0, 10.0, 11.0, 12.0]]           [0.0321, 0.0871, 0.2369, 0.6439]]
        '''
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn


# 6. Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # The incoming Q, K, V tensors are the same here; the linear layers below play the role of the projection matrices Wq, Wk, Wv.
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        # Steps: project and split into heads, compute the attention scores, then weight the values.
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        # Project first, then split into heads. Note that q and k must have the same per-head dimension (they are multiplied together); here both use d_k.
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
        # self.W_Q(input_Q) applies the W_Q linear layer to the queries.
        # .view(batch_size, -1, n_heads, d_k) then reshapes the result: batch_size is the batch size, -1 is inferred automatically,
        # n_heads is the number of attention heads and d_k is the per-head dimension, preparing the tensor for multi-head attention.
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        # .repeat(1, n_heads, 1, 1) copies the mask along the newly added head dimension (n_heads copies, the other dimensions unchanged),
        # giving shape (batch_size, n_heads, seq_len, seq_len) so that every head uses the same mask.
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        # Run the scaled dot-product attention (see 7).
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)        # context: [batch_size, n_heads, len_q, d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model).cuda()(output + residual), attn


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False), # d_model = 512, d_ff = 2048
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual) # [batch_size, seq_len, d_model]

# 5. EncoderLayer: two sub-layers, multi-head self-attention and a position-wise feed-forward network
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention() # multi-head self-attention layer
        self.pos_ffn = PoswiseFeedForwardNet() # fully connected feed-forward network

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        # See the multi-head attention implementation (6).
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn


# 2. The Encoder has three parts: the word embedding, the positional encoding, and the stacked attention + feed-forward layers.
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model) # an embedding matrix of size src_vocab_size * d_model; src_vocab_size is the number of words in the source vocabulary (6 here)
        self.pos_emb = PositionalEncoding(d_model) # positional encoding; fixed sine/cosine functions are used here, but a learnable nn.Embedding would also work
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)]) # stack several EncoderLayers with ModuleList; the embedding and positional encoding are kept outside because only the first layer needs them

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        # src_emb looks up each token index, so enc_outputs has shape [batch_size, src_len, d_model]: every word becomes a word vector.
        enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]

        # Positional encoding; the addition happens inside the module (see 3).
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, src_len, d_model]

        # get_attn_pad_mask marks the pad positions so that the self-attention and cross attention can ignore the pad symbols (see 4).
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # See EncoderLayer (5).
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

# 9. Decoder
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda() # [batch_size, tgt_len, d_model]
        # get_attn_pad_mask: pad mask for the decoder self-attention
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]

        # get_attn_subsequence_mask: the current token must not see the following tokens, implemented with an upper triangular matrix of ones
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]

        # Add the two masks: entries greater than 0 become 1 (and will later be filled with a very large negative value); the rest become 0.
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda() # [batch_size, tgt_len, tgt_len]

        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len],
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

# 1. The overall network has three parts: the encoder, the decoder, and the output projection.
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()  # encoder
        self.decoder = Decoder().cuda()  # decoder
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()  # output layer; d_model is the dimension of each decoder token output, projected to tgt_vocab_size (9) before the softmax
        # The projection maps the 512-dimensional output to tgt_vocab_size, after which a softmax predicts which word to output.

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # Two inputs: enc_inputs of shape [batch_size, src_len] for the encoder, and dec_inputs of shape [batch_size, tgt_len] for the decoder.

        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)

        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        # enc_inputs has shape [batch_size, src_len]; the output shapes are decided inside the encoder.
        # enc_outputs is the main output; enc_self_attns holds the softmax(QK^T) matrices, i.e. how related each word is to the others, kept mainly for visualization.
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)

        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]
        # The decoder takes two inputs: the decoder input (the "outputs" in the model figure) and the encoder output (for the cross attention).
        # dec_outputs is the main output, used for the final linear projection; dec_self_attns is the decoder-side analogue of enc_self_attns; dec_enc_attns shows how each decoder token attends to each encoder token.
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)

        # Project dec_outputs to the vocabulary size.
        dec_logits = self.projection(dec_outputs)  # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns


model = Transformer().cuda()
# Two rules of thumb when writing a model:
# 1. Build the overall skeleton first, then fill in the details.
# 2. Always keep track of tensor shapes: know the input and output shape of every module (knowing the output shape tells you part of the next input shape and how to write the code).
criterion = nn.CrossEntropyLoss(ignore_index=0)
# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

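Before moving on, here is a small illustrative sketch of my own (not part of the original notebook; the variable names are my own) showing how the decoder self-attention mask built inside Decoder.forward combines the pad mask and the subsequence mask: a position is masked if its key is padding or lies in the future.

dec_in = torch.tensor([[6, 1, 2, 0, 0]])          # a toy decoder input with two pad tokens
pad_mask = get_attn_pad_mask(dec_in, dec_in)      # [1, 5, 5], True where the key position is padding
sub_mask = get_attn_subsequence_mask(dec_in)      # [1, 5, 5], 1 strictly above the diagonal
combined = torch.gt(pad_mask + sub_mask, 0)       # True if the position is padded OR in the future
print(combined[0])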

Verifying the multi-head attention module

i = torch.rand(1, 60).to('cuda')
print(i.shape)
a = get_attn_pad_mask(i, i).to('cuda')
temp_mha = MultiHeadAttention().to('cuda')
# create a piece of dummy data
y = torch.rand(1, 60, 512).to('cuda')
# run the attention, using y as q, k and v at the same time
output, attn = temp_mha.forward(y, y, y, attn_mask = a)
print(output.shape)
print(attn.shape)
torch.Size([1, 60])
torch.Size([1, 60, 512])
torch.Size([1, 8, 60, 60])

Custom learning-rate schedule with the optimizer

This gives better results than a fixed learning rate.
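For reference, the schedule implemented below is the warmup schedule from the paper:

$$
lrate = d_{model}^{-0.5} \cdot \min\bigl(step\_num^{-0.5},\ step\_num \cdot warmup\_steps^{-1.5}\bigr)
$$

The learning rate grows roughly linearly for the first warmup_steps steps and then decays proportionally to the inverse square root of the step number.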

class CustomizedSchedule1:
    def __init__(self, d_model, warmup_steps=4000):
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = (step + 1) ** -0.5  # avoid a negative power of 0 at step 0
        arg2 = step * (self.warmup_steps ** (-1.5))
        arg3 = (self.d_model ** -0.5)
        return arg3 * min(arg1, arg2)
learning_rate_fn = CustomizedSchedule1(d_model)

# Create the optimizer and the learning-rate scheduler.
optimizer1 = torch.optim.Adam(model.parameters(), lr=1.0)  # LambdaLR multiplies this base lr by the value returned by learning_rate_fn, so it must not be 0
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer1, lr_lambda=learning_rate_fn)
# The plot below shows the effect of warmup_steps: the learning rate rises during warmup and then decays slowly once warmup_steps is reached (one training batch = one step).
temp_learning_rate_schedule = CustomizedSchedule1(d_model)
# Plot the learning rate as a function of the training step.
steps = torch.arange(40000, dtype=torch.float32)
learning_rates = [temp_learning_rate_schedule(step).item() for step in steps]

plt.plot(steps.numpy(), learning_rates)
plt.ylabel("Learning rate")
plt.xlabel("Train step")
plt.show()

(figure: learning rate vs. training step)
Training

for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        dec_outputs: [batch_size, tgt_len]
        '''
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        # view(-1) flattens dec_outputs into a one-dimensional tensor.
        loss = criterion(outputs, dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()

        # Update the learning rate via the scheduler.
        scheduler1.step()
Epoch: 0001 loss = 2.700844
Epoch: 0002 loss = 2.676570
Epoch: 0003 loss = 2.627945
Epoch: 0004 loss = 2.606202
Epoch: 0005 loss = 2.650645
Epoch: 0006 loss = 2.588106
Epoch: 0007 loss = 2.445351
Epoch: 0008 loss = 2.508237
Epoch: 0009 loss = 2.447822
Epoch: 0010 loss = 2.325812
Epoch: 0011 loss = 2.384464
Epoch: 0012 loss = 2.271541
Epoch: 0013 loss = 2.213017
Epoch: 0014 loss = 2.177427
Epoch: 0015 loss = 2.198853
Epoch: 0016 loss = 2.030714
Epoch: 0017 loss = 1.964916
Epoch: 0018 loss = 1.903209
Epoch: 0019 loss = 1.902373
Epoch: 0020 loss = 1.778085
Epoch: 0021 loss = 1.725913
Epoch: 0022 loss = 1.744313
Epoch: 0023 loss = 1.646813
Epoch: 0024 loss = 1.606654
Epoch: 0025 loss = 1.525460
Epoch: 0026 loss = 1.460560
Epoch: 0027 loss = 1.428008
Epoch: 0028 loss = 1.362813
Epoch: 0029 loss = 1.315524
Epoch: 0030 loss = 1.295405
Epoch: 0031 loss = 1.200210
Epoch: 0032 loss = 1.158323
Epoch: 0033 loss = 1.108451
Epoch: 0034 loss = 0.951094
Epoch: 0035 loss = 0.981314
Epoch: 0036 loss = 0.951581
Epoch: 0037 loss = 0.815680
Epoch: 0038 loss = 0.751278
Epoch: 0039 loss = 0.715567
Epoch: 0040 loss = 0.625560
Epoch: 0041 loss = 0.587802
Epoch: 0042 loss = 0.575087
Epoch: 0043 loss = 0.506960
Epoch: 0044 loss = 0.455274
Epoch: 0045 loss = 0.419664
Epoch: 0046 loss = 0.354564
Epoch: 0047 loss = 0.351857
Epoch: 0048 loss = 0.297130
Epoch: 0049 loss = 0.278830
Epoch: 0050 loss = 0.283241
Epoch: 0051 loss = 0.212995
Epoch: 0052 loss = 0.218633
Epoch: 0053 loss = 0.181479
Epoch: 0054 loss = 0.162638
Epoch: 0055 loss = 0.146464
Epoch: 0056 loss = 0.134797
Epoch: 0057 loss = 0.101864
Epoch: 0058 loss = 0.116349
Epoch: 0059 loss = 0.098998
Epoch: 0060 loss = 0.104876
Epoch: 0061 loss = 0.083773
Epoch: 0062 loss = 0.074446
Epoch: 0063 loss = 0.078839
Epoch: 0064 loss = 0.073547
Epoch: 0065 loss = 0.064128
Epoch: 0066 loss = 0.060836
Epoch: 0067 loss = 0.053429
Epoch: 0068 loss = 0.049494
Epoch: 0069 loss = 0.056137
Epoch: 0070 loss = 0.049774
Epoch: 0071 loss = 0.042814
Epoch: 0072 loss = 0.041327
Epoch: 0073 loss = 0.044661
Epoch: 0074 loss = 0.036507
Epoch: 0075 loss = 0.035213
Epoch: 0076 loss = 0.033580
Epoch: 0077 loss = 0.028917
Epoch: 0078 loss = 0.030519
Epoch: 0079 loss = 0.031743
Epoch: 0080 loss = 0.024252
Epoch: 0081 loss = 0.028838
Epoch: 0082 loss = 0.026516
Epoch: 0083 loss = 0.022355
Epoch: 0084 loss = 0.023337
Epoch: 0085 loss = 0.022203
Epoch: 0086 loss = 0.022297
Epoch: 0087 loss = 0.020514
Epoch: 0088 loss = 0.021469
Epoch: 0089 loss = 0.019400
Epoch: 0090 loss = 0.017382
Epoch: 0091 loss = 0.017562
Epoch: 0092 loss = 0.017067
Epoch: 0093 loss = 0.017831
Epoch: 0094 loss = 0.016011
Epoch: 0095 loss = 0.016058
Epoch: 0096 loss = 0.016809
Epoch: 0097 loss = 0.016994
Epoch: 0098 loss = 0.015828
Epoch: 0099 loss = 0.015224
Epoch: 0100 loss = 0.014666
Epoch: 0101 loss = 0.013139
Epoch: 0102 loss = 0.014083
Epoch: 0103 loss = 0.014809
Epoch: 0104 loss = 0.013538
Epoch: 0105 loss = 0.012494
Epoch: 0106 loss = 0.014155
Epoch: 0107 loss = 0.012296
Epoch: 0108 loss = 0.012803
Epoch: 0109 loss = 0.012513
Epoch: 0110 loss = 0.011238
Epoch: 0111 loss = 0.010483
Epoch: 0112 loss = 0.011260
Epoch: 0113 loss = 0.010968
Epoch: 0114 loss = 0.010077
Epoch: 0115 loss = 0.011757
Epoch: 0116 loss = 0.010623
Epoch: 0117 loss = 0.009586
Epoch: 0118 loss = 0.011541
Epoch: 0119 loss = 0.009942
Epoch: 0120 loss = 0.009321
Epoch: 0121 loss = 0.009212
Epoch: 0122 loss = 0.010276
Epoch: 0123 loss = 0.008764
Epoch: 0124 loss = 0.009939
Epoch: 0125 loss = 0.009529
Epoch: 0126 loss = 0.009097
Epoch: 0127 loss = 0.009702
Epoch: 0128 loss = 0.008433
Epoch: 0129 loss = 0.007843
Epoch: 0130 loss = 0.007821
Epoch: 0131 loss = 0.007701
Epoch: 0132 loss = 0.008602
Epoch: 0133 loss = 0.007557
Epoch: 0134 loss = 0.007202
Epoch: 0135 loss = 0.007281
Epoch: 0136 loss = 0.008016
Epoch: 0137 loss = 0.007663
Epoch: 0138 loss = 0.006558
Epoch: 0139 loss = 0.006710
Epoch: 0140 loss = 0.007544
Epoch: 0141 loss = 0.007100
Epoch: 0142 loss = 0.006660
Epoch: 0143 loss = 0.006323
Epoch: 0144 loss = 0.007174
Epoch: 0145 loss = 0.007030
Epoch: 0146 loss = 0.006387
Epoch: 0147 loss = 0.006057
Epoch: 0148 loss = 0.006924
Epoch: 0149 loss = 0.006338
Epoch: 0150 loss = 0.006006
Epoch: 0151 loss = 0.006640
Epoch: 0152 loss = 0.005767
Epoch: 0153 loss = 0.006193
Epoch: 0154 loss = 0.006098
Epoch: 0155 loss = 0.005454
Epoch: 0156 loss = 0.005646
Epoch: 0157 loss = 0.005594
Epoch: 0158 loss = 0.005090
Epoch: 0159 loss = 0.005108
Epoch: 0160 loss = 0.005298
Epoch: 0161 loss = 0.005076
Epoch: 0162 loss = 0.004849
Epoch: 0163 loss = 0.004957
Epoch: 0164 loss = 0.004956
Epoch: 0165 loss = 0.004824
Epoch: 0166 loss = 0.005442
Epoch: 0167 loss = 0.004251
Epoch: 0168 loss = 0.004888
......
Epoch: 0302 loss = 0.001038
Epoch: 0303 loss = 0.001023
Epoch: 0304 loss = 0.000993
Epoch: 0305 loss = 0.000966
Epoch: 0306 loss = 0.001011
Epoch: 0307 loss = 0.001065
Epoch: 0308 loss = 0.001040
Epoch: 0309 loss = 0.000844
Epoch: 0310 loss = 0.000965
Epoch: 0311 loss = 0.001100
Epoch: 0312 loss = 0.001046
Epoch: 0313 loss = 0.000882
Epoch: 0314 loss = 0.000906
Epoch: 0315 loss = 0.000894
Epoch: 0316 loss = 0.000892
Epoch: 0317 loss = 0.000900
Epoch: 0318 loss = 0.001025
Epoch: 0319 loss = 0.000956
Epoch: 0320 loss = 0.001024
Epoch: 0321 loss = 0.000897
Epoch: 0322 loss = 0.000921
Epoch: 0323 loss = 0.000793
Epoch: 0324 loss = 0.000866
Epoch: 0325 loss = 0.000796
Epoch: 0326 loss = 0.000834
Epoch: 0327 loss = 0.000856
Epoch: 0328 loss = 0.000763
Epoch: 0329 loss = 0.000883
Epoch: 0330 loss = 0.000763
Epoch: 0331 loss = 0.000690
Epoch: 0332 loss = 0.000756
Epoch: 0333 loss = 0.000788
Epoch: 0334 loss = 0.000847
Epoch: 0335 loss = 0.000725
Epoch: 0336 loss = 0.000763
Epoch: 0337 loss = 0.000725
Epoch: 0338 loss = 0.000656
Epoch: 0339 loss = 0.000789
Epoch: 0340 loss = 0.000730
Epoch: 0341 loss = 0.000705
Epoch: 0342 loss = 0.000837
Epoch: 0343 loss = 0.000703
Epoch: 0344 loss = 0.000695
Epoch: 0345 loss = 0.000679
Epoch: 0346 loss = 0.000641
Epoch: 0347 loss = 0.000685
Epoch: 0348 loss = 0.000693
Epoch: 0349 loss = 0.000647
Epoch: 0350 loss = 0.000615
Epoch: 0351 loss = 0.000744
Epoch: 0352 loss = 0.000730
Epoch: 0353 loss = 0.000637
Epoch: 0354 loss = 0.000790
Epoch: 0355 loss = 0.000594
Epoch: 0356 loss = 0.000795
Epoch: 0357 loss = 0.000631
Epoch: 0358 loss = 0.000591
Epoch: 0359 loss = 0.000648
Epoch: 0360 loss = 0.000670
Epoch: 0361 loss = 0.000523
Epoch: 0362 loss = 0.000529
Epoch: 0363 loss = 0.000568
Epoch: 0364 loss = 0.000566
Epoch: 0365 loss = 0.000552
Epoch: 0366 loss = 0.000576
......
Epoch: 0554 loss = 0.000147
Epoch: 0555 loss = 0.000168
Epoch: 0556 loss = 0.000158
Epoch: 0557 loss = 0.000180
Epoch: 0558 loss = 0.000146
Epoch: 0559 loss = 0.000140
Epoch: 0560 loss = 0.000141
Epoch: 0561 loss = 0.000144
Epoch: 0562 loss = 0.000151
Epoch: 0563 loss = 0.000136
Epoch: 0564 loss = 0.000153
Epoch: 0565 loss = 0.000130
Epoch: 0566 loss = 0.000137
Epoch: 0567 loss = 0.000128
Epoch: 0568 loss = 0.000133
Epoch: 0569 loss = 0.000125
Epoch: 0570 loss = 0.000131
Epoch: 0571 loss = 0.000143
Epoch: 0572 loss = 0.000132
Epoch: 0573 loss = 0.000128
Epoch: 0574 loss = 0.000135
Epoch: 0575 loss = 0.000132
Epoch: 0576 loss = 0.000123
Epoch: 0577 loss = 0.000128
Epoch: 0578 loss = 0.000117
Epoch: 0579 loss = 0.000126
Epoch: 0580 loss = 0.000153
Epoch: 0581 loss = 0.000123
Epoch: 0582 loss = 0.000133
Epoch: 0583 loss = 0.000122
Epoch: 0584 loss = 0.000132
Epoch: 0585 loss = 0.000117
Epoch: 0586 loss = 0.000129
Epoch: 0587 loss = 0.000124
Epoch: 0588 loss = 0.000119
Epoch: 0589 loss = 0.000127
Epoch: 0590 loss = 0.000123
Epoch: 0591 loss = 0.000102
Epoch: 0592 loss = 0.000128
Epoch: 0593 loss = 0.000130
Epoch: 0594 loss = 0.000140
Epoch: 0595 loss = 0.000116
Epoch: 0596 loss = 0.000104
Epoch: 0597 loss = 0.000110
Epoch: 0598 loss = 0.000128
Epoch: 0599 loss = 0.000129
Epoch: 0600 loss = 0.000113
Epoch: 0601 loss = 0.000107
Epoch: 0602 loss = 0.000112
Epoch: 0603 loss = 0.000111
Epoch: 0604 loss = 0.000113
Epoch: 0605 loss = 0.000116
Epoch: 0606 loss = 0.000121
Epoch: 0607 loss = 0.000119
Epoch: 0608 loss = 0.000119
Epoch: 0609 loss = 0.000123
Epoch: 0610 loss = 0.000108
Epoch: 0611 loss = 0.000125
Epoch: 0612 loss = 0.000108
Epoch: 0613 loss = 0.000118
Epoch: 0614 loss = 0.000108
Epoch: 0615 loss = 0.000119
Epoch: 0616 loss = 0.000110
Epoch: 0617 loss = 0.000111
Epoch: 0618 loss = 0.000105
Epoch: 0619 loss = 0.000103
Epoch: 0620 loss = 0.000097
Epoch: 0621 loss = 0.000112
Epoch: 0622 loss = 0.000092
Epoch: 0623 loss = 0.000105
Epoch: 0624 loss = 0.000108
Epoch: 0625 loss = 0.000101
Epoch: 0626 loss = 0.000089
Epoch: 0627 loss = 0.000105
Epoch: 0628 loss = 0.000097
Epoch: 0629 loss = 0.000103
Epoch: 0630 loss = 0.000109
Epoch: 0631 loss = 0.000102
Epoch: 0632 loss = 0.000087
......
Epoch: 0764 loss = 0.000061
Epoch: 0765 loss = 0.000063
Epoch: 0766 loss = 0.000061
Epoch: 0767 loss = 0.000060
Epoch: 0768 loss = 0.000061
Epoch: 0769 loss = 0.000064
Epoch: 0770 loss = 0.000058
Epoch: 0771 loss = 0.000061
Epoch: 0772 loss = 0.000064
Epoch: 0773 loss = 0.000064
Epoch: 0774 loss = 0.000063
Epoch: 0775 loss = 0.000058
Epoch: 0776 loss = 0.000057
Epoch: 0777 loss = 0.000060
Epoch: 0778 loss = 0.000058
Epoch: 0779 loss = 0.000061
Epoch: 0780 loss = 0.000061
Epoch: 0781 loss = 0.000059
Epoch: 0782 loss = 0.000058
Epoch: 0783 loss = 0.000060
Epoch: 0784 loss = 0.000055
Epoch: 0785 loss = 0.000063
Epoch: 0786 loss = 0.000056
Epoch: 0787 loss = 0.000056
Epoch: 0788 loss = 0.000058
Epoch: 0789 loss = 0.000060
Epoch: 0790 loss = 0.000057
Epoch: 0791 loss = 0.000055
Epoch: 0792 loss = 0.000050
Epoch: 0793 loss = 0.000050
Epoch: 0794 loss = 0.000051
Epoch: 0795 loss = 0.000058
Epoch: 0796 loss = 0.000052
Epoch: 0797 loss = 0.000057
Epoch: 0798 loss = 0.000054
Epoch: 0799 loss = 0.000051
Epoch: 0800 loss = 0.000057
Epoch: 0801 loss = 0.000055
Epoch: 0802 loss = 0.000052
Epoch: 0803 loss = 0.000054
Epoch: 0804 loss = 0.000052
Epoch: 0805 loss = 0.000056
Epoch: 0806 loss = 0.000053
Epoch: 0807 loss = 0.000055
Epoch: 0808 loss = 0.000056
Epoch: 0809 loss = 0.000058
Epoch: 0810 loss = 0.000054
Epoch: 0811 loss = 0.000055
Epoch: 0812 loss = 0.000049
Epoch: 0813 loss = 0.000057
Epoch: 0814 loss = 0.000053
Epoch: 0815 loss = 0.000053
Epoch: 0816 loss = 0.000052
Epoch: 0817 loss = 0.000047
Epoch: 0818 loss = 0.000051
Epoch: 0819 loss = 0.000051
Epoch: 0820 loss = 0.000052
Epoch: 0821 loss = 0.000053
Epoch: 0822 loss = 0.000054
Epoch: 0823 loss = 0.000057
Epoch: 0824 loss = 0.000050
Epoch: 0825 loss = 0.000047
Epoch: 0826 loss = 0.000051
Epoch: 0827 loss = 0.000048
Epoch: 0828 loss = 0.000048
Epoch: 0829 loss = 0.000050
Epoch: 0830 loss = 0.000050
Epoch: 0831 loss = 0.000052
Epoch: 0832 loss = 0.000049
Epoch: 0833 loss = 0.000049
Epoch: 0834 loss = 0.000052
Epoch: 0835 loss = 0.000050
Epoch: 0836 loss = 0.000049
Epoch: 0837 loss = 0.000046
Epoch: 0838 loss = 0.000047
Epoch: 0839 loss = 0.000047
Epoch: 0840 loss = 0.000054
Epoch: 0841 loss = 0.000048
Epoch: 0842 loss = 0.000050
Epoch: 0843 loss = 0.000051
Epoch: 0844 loss = 0.000046
Epoch: 0845 loss = 0.000046
Epoch: 0846 loss = 0.000047
Epoch: 0847 loss = 0.000050
Epoch: 0848 loss = 0.000051
Epoch: 0849 loss = 0.000049
Epoch: 0850 loss = 0.000048
Epoch: 0851 loss = 0.000045
Epoch: 0852 loss = 0.000051
Epoch: 0853 loss = 0.000050
Epoch: 0854 loss = 0.000045
Epoch: 0855 loss = 0.000049
Epoch: 0856 loss = 0.000045
Epoch: 0857 loss = 0.000048
Epoch: 0858 loss = 0.000046
Epoch: 0859 loss = 0.000044
Epoch: 0860 loss = 0.000044
Epoch: 0861 loss = 0.000048
Epoch: 0862 loss = 0.000045
Epoch: 0863 loss = 0.000047
Epoch: 0864 loss = 0.000046
Epoch: 0865 loss = 0.000046
Epoch: 0866 loss = 0.000048
Epoch: 0867 loss = 0.000045
Epoch: 0868 loss = 0.000049
Epoch: 0869 loss = 0.000044
Epoch: 0870 loss = 0.000045
Epoch: 0871 loss = 0.000047
Epoch: 0872 loss = 0.000047
Epoch: 0873 loss = 0.000046
Epoch: 0874 loss = 0.000045
Epoch: 0875 loss = 0.000046
Epoch: 0876 loss = 0.000045
Epoch: 0877 loss = 0.000047
Epoch: 0878 loss = 0.000044
Epoch: 0879 loss = 0.000047
Epoch: 0880 loss = 0.000046
Epoch: 0881 loss = 0.000045
Epoch: 0882 loss = 0.000042
Epoch: 0883 loss = 0.000044
Epoch: 0884 loss = 0.000047
Epoch: 0885 loss = 0.000041
Epoch: 0886 loss = 0.000045
Epoch: 0887 loss = 0.000044
Epoch: 0888 loss = 0.000042
Epoch: 0889 loss = 0.000039
......
Epoch: 0938 loss = 0.000034
Epoch: 0939 loss = 0.000037
Epoch: 0940 loss = 0.000039
Epoch: 0941 loss = 0.000042
Epoch: 0942 loss = 0.000037
Epoch: 0943 loss = 0.000036
Epoch: 0944 loss = 0.000039
Epoch: 0945 loss = 0.000036
Epoch: 0946 loss = 0.000039
Epoch: 0947 loss = 0.000037
Epoch: 0948 loss = 0.000037
Epoch: 0949 loss = 0.000038
Epoch: 0950 loss = 0.000037
Epoch: 0951 loss = 0.000041
Epoch: 0952 loss = 0.000036
Epoch: 0953 loss = 0.000037
Epoch: 0954 loss = 0.000039
Epoch: 0955 loss = 0.000037
Epoch: 0956 loss = 0.000038
Epoch: 0957 loss = 0.000036
Epoch: 0958 loss = 0.000039
Epoch: 0959 loss = 0.000035
Epoch: 0960 loss = 0.000038
Epoch: 0961 loss = 0.000039
Epoch: 0962 loss = 0.000038
Epoch: 0963 loss = 0.000038
Epoch: 0964 loss = 0.000036
Epoch: 0965 loss = 0.000035
Epoch: 0966 loss = 0.000034
Epoch: 0967 loss = 0.000037
Epoch: 0968 loss = 0.000036
Epoch: 0969 loss = 0.000035
Epoch: 0970 loss = 0.000035
Epoch: 0971 loss = 0.000038
Epoch: 0972 loss = 0.000036
Epoch: 0973 loss = 0.000036
Epoch: 0974 loss = 0.000037
Epoch: 0975 loss = 0.000034
Epoch: 0976 loss = 0.000036
Epoch: 0977 loss = 0.000033
Epoch: 0978 loss = 0.000037
Epoch: 0979 loss = 0.000035
Epoch: 0980 loss = 0.000035
Epoch: 0981 loss = 0.000034
Epoch: 0982 loss = 0.000035
Epoch: 0983 loss = 0.000034
Epoch: 0984 loss = 0.000033
Epoch: 0985 loss = 0.000034
Epoch: 0986 loss = 0.000037
Epoch: 0987 loss = 0.000033
Epoch: 0988 loss = 0.000035
Epoch: 0989 loss = 0.000034
Epoch: 0990 loss = 0.000035
Epoch: 0991 loss = 0.000035
Epoch: 0992 loss = 0.000034
Epoch: 0993 loss = 0.000032
Epoch: 0994 loss = 0.000037
Epoch: 0995 loss = 0.000035
Epoch: 0996 loss = 0.000037
Epoch: 0997 loss = 0.000034
Epoch: 0998 loss = 0.000034
Epoch: 0999 loss = 0.000037
Epoch: 1000 loss = 0.000036

def greedy_decoder(model, enc_input, start_symbol):
    """
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S', which corresponds to index 6
    :return: The target input
    This is a greedy decoder: at inference time there is no target sequence, so the target input is generated one token at a time.
    The encoder input is encoded first, then an empty decoder input tensor dec_input is initialized. In a loop, the next target
    token is generated and appended to dec_input. At every step the model receives the current dec_input and the encoder outputs
    and produces dec_outputs; after the projection layer model.projection, the word with the highest probability is taken as the
    next target token. The loop stops once the next token is the period ".".

    A few implementation details: torch.cat concatenates tensors, detach separates the data from the computation graph, and
    squeeze removes a singleton dimension; these keep the tensor shapes and dtypes compatible with the model.

    The function returns the generated target input dec_input.
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    # Create a tensor of shape (1, 0) with the same dtype as enc_input.
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        # torch.cat concatenates the tensors along the last dimension.
        # dec_input.detach() separates the data from the graph and creates a new tensor;
        # torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda() turns next_symbol into a tensor and moves it to the GPU.
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        # squeeze(0) removes the singleton batch dimension.
        # .max(dim=-1, keepdim=False)[1] takes the argmax over the last dimension (the vocabulary);
        # keepdim=False drops the reduced dimension, so prob holds the index of the most likely word at every position.
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
        print(next_word)
    return dec_input
# Test
enc_inputs, _, _ = next(iter(loader))
enc_inputs = enc_inputs.cuda()
for i in range(len(enc_inputs)):
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(5, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 5, 0], device='cuda:0') -> ['i', 'want', 'a', 'coke', '.']
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(4, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 4, 0], device='cuda:0') -> ['i', 'want', 'a', 'beer', '.']