当前位置:   article > 正文

Transformer代码详解_matlab transformer

matlab transformer

本教程适用于对Transformer理论有一定理解的朋友。理论部分请看其他教程,本文详解代码。

Embedding

Embedding很好理解,vocab表示词表大小,d_model表示embedding大小。至于返回值为什么乘上sqrt(self.d_model) 目前还不是很理解。

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

位置编码

因为子注意力机制在算注意力权重的时候,并没有考虑到词语前后关系,而是考虑了整体的上下文,因此需要加入位置编码。主要的数学公式如下所示:
在这里插入图片描述
pos可以理解成每个字符在一句话的位置,i可以理解为在embedding向量的位置。这里假设betch_size设置为1,那么一句话的矩阵表达就是:[seq_len, d_model]。在我们一开始得到了embedding后的矩阵,需要再加上PositionalEncoding矩阵,它的维度也是[seq_len, d_model],下面让我们看看如何得到这个矩阵。

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.pe = torch.zeros(max_len, d_model)
        # [man_len, 1]
        position = torch.arange(0, max_len).unsqueeze(1)
        # (d_model/2, )
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        # pe: [max_len, d_model]  矩阵的扩充运算
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        # pe : [1, max_len, d_model] 第一个维度是batch_size
        self.pe = self.pe.unsqueeze(0)

    def forward(self, x):
        # 输入的x维度: [batch_size, seq_len, d_model]
        # 因为输入的句子的长度会比设定的max_len小,因为需要切片操作
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

画出图形看一下形状:

import matplotlib.pyplot as plt
plt.figure(figsize=(15,5))
model = PositionalEncoding(20, 0)
x = torch.zeros(1, 100, 20)
plt.plot(range(100), model(x)[0,:, 4:8])
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

可以看到周期性的变化,这让每个位置的值都有了自己的position
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/653456
推荐阅读
相关标签
  

闽ICP备14008679号