当前位置:   article > 正文

Transformer的Encoder和Decoder之间的交互

Transformer的Encoder和Decoder之间的交互

Transformer的Encoder和Decoder之间的交互

flyfish

这个示例代码创建了一个小的Transformer模型,并演示了如何在Encoder和Decoder之间进行交互。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 定义位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

# 定义Transformer模型
class TransformerModel(nn.Module):
    def __init__(self, input_dim, model_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward=512, max_len=5000):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        self.embedding = nn.Embedding(input_dim, model_dim)
        self.positional_encoding = PositionalEncoding(model_dim, max_len)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(model_dim, nhead, dim_feedforward), num_encoder_layers)

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(model_dim, nhead, dim_feedforward), num_decoder_layers)

        self.fc_out = nn.Linear(model_dim, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        src_emb = self.positional_encoding(self.embedding(src) * math.sqrt(self.model_dim))
        tgt_emb = self.positional_encoding(self.embedding(tgt) * math.sqrt(self.model_dim))

        memory = self.encoder(src_emb, mask=src_mask)
        output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        return self.fc_out(output)

# 超参数定义
input_dim = 1000
model_dim = 512
output_dim = 1000
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
max_len = 5000

# 创建Transformer模型实例
model = TransformerModel(input_dim, model_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_len)

# 定义示例输入
src = torch.randint(0, input_dim, (10, 32))  # (source sequence length, batch size)
tgt = torch.randint(0, input_dim, (20, 32))  # (target sequence length, batch size)

# 前向传播
output = model(src, tgt)

# 打印输出形状
print(output.shape)  # (target sequence length, batch size, output dimension)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66

输出

torch.Size([20, 32, 1000])
  • 1

位置编码 (Positional Encoding):
用于在输入中加入位置信息,使模型能够考虑序列顺序。

Transformer模型 (TransformerModel):
包括Embedding层、位置编码层、Encoder、Decoder和输出的全连接层。

前向传播:
将输入源序列 (src) 和目标序列 (tgt) 通过嵌入层和位置编码。
使用Encoder对源序列进行编码,得到记忆 (memory)。
使用Decoder对目标序列进行解码,结合记忆生成输出。

超参数:
定义模型的维度、头数、层数等。

register_buffer 是 PyTorch 中 nn.Module 类的方法,用于注册一个持久的缓冲区,这些缓冲区不是模型的参数,但在训练和推理过程中需要被保存和加载。例如,位置编码就是这样一种缓冲区,它不需要进行梯度更新,但需要在模型保存和加载时保持不变。

下面是一个简单的例子,展示如何使用 register_buffer 注册一个缓冲区:

import torch
import torch.nn as nn

class ExampleModule(nn.Module):
    def __init__(self):
        super(ExampleModule, self).__init__()
        # 注册一个缓冲区
        buffer = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        self.register_buffer('my_buffer', buffer)
        # 一个简单的线性层
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        # 使用缓冲区进行一些操作
        x = x + self.my_buffer
        return self.linear(x)

# 创建模型实例
model = ExampleModule()

# 打印模型结构
print(model)

# 定义输入张量
input_tensor = torch.tensor([1, 1, 1, 1], dtype=torch.float32)

# 前向传播
output = model(input_tensor)
print(output)

# 打印缓冲区
print("Buffer:", model.my_buffer)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

输出

ExampleModule(
  (linear): Linear(in_features=4, out_features=2, bias=True)
)
tensor([-3.2452,  0.5913], grad_fn=<ViewBackward0>)
Buffer: tensor([1., 2., 3., 4.])
  • 1
  • 2
  • 3
  • 4
  • 5
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/681321
推荐阅读
相关标签
  

闽ICP备14008679号