# Implement the encoder-decoder architecture with an EncoderDecoder class
import copy

import torch
import torch.nn as nn
from torch.autograd import Variable


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        """The initializer takes 5 arguments: an encoder object, a decoder object,
        a source-data embedding function, a target-data embedding function,
        and the class generator object for the output part.
        """
        super(EncoderDecoder, self).__init__()
        # Store the arguments on the instance
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = source_embed
        self.tgt_embed = target_embed
        self.generator = generator

    def forward(self, source, target, source_mask, target_mask):
        """forward takes four arguments: source is the source data, target is
        the target data, and source_mask and target_mask are the corresponding
        mask tensors."""

        # Pass source and source_mask to the encoding function, then feed its
        # result, together with source_mask, target, and target_mask, to the
        # decoding function.
        return self.decode(self.encode(source, source_mask), source_mask,
                           target, target_mask)

    def encode(self, source, source_mask):
        """Encoding function, taking source and source_mask as arguments."""
        # Process source with src_embed, then pass it along with source_mask to self.encoder
        return self.encoder(self.src_embed(source), source_mask)

    def decode(self, memory, source_mask, target, target_mask):
        """Decoding function, taking memory (the encoder output), source_mask,
        target, and target_mask as arguments."""
        # Process target with tgt_embed, then pass it along with memory,
        # source_mask, and target_mask to self.decoder
        return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)
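Before plugging in real components, the wiring above can be sanity-checked in isolation. Below is a minimal sketch (an addition, not part of the original tutorial) that passes hypothetical pass-through stubs into EncoderDecoder to confirm that forward routes tensors through encode and then decode as described:

# Hypothetical identity stand-ins, used only to exercise the wiring
class StubEncoder(nn.Module):
    def forward(self, x, mask):
        return x

class StubDecoder(nn.Module):
    def forward(self, x, memory, source_mask, target_mask):
        return x

embed = nn.Embedding(1000, 512)
# The generator is stored but never invoked by forward, so None is acceptable here
stub_model = EncoderDecoder(StubEncoder(), StubDecoder(), embed, embed, None)
x = torch.LongTensor([[100, 2, 421, 508]])
print(stub_model(x, x, None, None).shape)  # torch.Size([1, 4, 512])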
Instantiation parameters:

vocab_size = 1000
d_model = 512
# en, de, and gen are assumed to be the encoder, decoder, and generator
# objects instantiated in the previous sections
encoder = en
decoder = de
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = gen
Input parameters:

# Assume the source data and the target data are identical; in practice they are not
# (Variable is a no-op wrapper in modern PyTorch; plain tensors work just as well)
source = target = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))

# Assume src_mask and tgt_mask are identical; in practice they are not
source_mask = target_mask = Variable(torch.zeros(8, 4, 4))
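As an aside (an addition, not from the original text), here is a hedged sketch of how the two masks typically differ in practice, following the common convention where 1 marks a visible position: the source mask hides only padding, while the target mask also hides future positions so the decoder cannot peek ahead.

pad = 0  # assumed padding index

# Source mask: visible wherever the token is not padding
real_source_mask = (source != pad).unsqueeze(-2)  # shape (batch, 1, src_len)

# Target mask: padding mask combined with a look-ahead ("subsequent") mask
tgt_len = target.size(1)
subsequent = torch.triu(torch.ones(1, tgt_len, tgt_len), diagonal=1) == 0
real_target_mask = (target != pad).unsqueeze(-2) & subsequent  # (batch, tgt_len, tgt_len)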
Call:

ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_result = ed(source, target, source_mask, target_mask)
print(ed_result)
print(ed_result.shape)
Output:

tensor([[[ 0.2102, -0.0826, -0.0550,  ...,  1.5555,  1.3025, -0.6296],
         [ 0.8270, -0.5372, -0.9559,  ...,  0.3665,  0.4338, -0.7505],
         [ 0.4956, -0.5133, -0.9323,  ...,  1.0773,  1.1913, -0.6240],
         [ 0.5770, -0.6258, -0.4833,  ...,  0.1171,  1.0069, -1.9030]],

        [[-0.4355, -1.7115, -1.5685,  ..., -0.6941, -0.1878, -0.1137],
         [-0.8867, -1.2207, -1.4151,  ..., -0.9618,  0.1722, -0.9562],
         [-0.0946, -0.9012, -1.6388,  ..., -0.2604, -0.3357, -0.6436],
         [-1.1204, -1.4481, -1.5888,  ..., -0.8816, -0.6497,  0.0606]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])
# MultiHeadedAttention, PositionwiseFeedForward, PositionalEncoding, Encoder,
# EncoderLayer, Decoder, DecoderLayer, Embeddings, and Generator are the
# classes implemented in the previous sections.
def make_model(source_vocab, target_vocab, N=6,
               d_model=512, d_ff=2048, head=8, dropout=0.1):
    """Builds the model. Takes 7 arguments: the total number of source
    features (vocabulary size), the total number of target features
    (vocabulary size), the number of stacked encoder and decoder layers,
    the word-embedding dimension, the dimension of the transformation matrix
    inside the position-wise feed-forward network, the number of heads in the
    multi-head attention structure, and the dropout (zeroing) rate."""

    # First get a deep-copy shortcut; many of the structures below need to be
    # deep-copied so that they are fully independent of one another.
    c = copy.deepcopy

    # Instantiate the multi-head attention class, obtaining the object attn
    attn = MultiHeadedAttention(head, d_model)

    # Instantiate the position-wise feed-forward class, obtaining the object ff
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)

    # Instantiate the positional-encoding class, obtaining the object position
    position = PositionalEncoding(d_model, dropout)

    # Following the architecture diagram, the outermost structure is
    # EncoderDecoder. Inside it are the encoder, the decoder, a Sequential of
    # the source-data Embedding layer and the positional encoding, a
    # Sequential of the target-data Embedding layer and the positional
    # encoding, and the class generator layer. Each encoder layer contains an
    # attention sublayer and a feed-forward sublayer; each decoder layer
    # contains two attention sublayers and a feed-forward sublayer.
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, source_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, target_vocab), c(position)),
        Generator(d_model, target_vocab))

    # Once the model structure is complete, initialize its parameters, such as
    # the transformation matrices in the linear layers. Whenever a parameter
    # has more than one dimension, initialize it from a uniform distribution
    # with Xavier scaling (note the trailing underscore: nn.init.xavier_uniform
    # is deprecated in favor of the in-place nn.init.xavier_uniform_).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
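To make the deep-copy comment in make_model concrete, here is a small sketch (an addition to the original text) showing that deep copies of a module own independent parameter storage, which is why the stacked layers do not share weights:

import copy
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
a, b = copy.deepcopy(layer), copy.deepcopy(layer)

# The copies start out numerically equal but live in separate storage,
# so training one never changes the other
print(torch.equal(a.weight, b.weight))             # True
print(a.weight.data_ptr() == b.weight.data_ptr())  # False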
An example of nn.init.xavier_uniform_ from the PyTorch documentation:

# The result follows a uniform distribution U(-a, a)
>>> w = torch.empty(3, 5)
>>> w = nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
>>> w
tensor([[-0.7742,  0.5413,  0.5478, -0.4806, -0.2555],
        [-0.8358,  0.4673,  0.3012,  0.3882, -0.6375],
        [ 0.4622, -0.0794,  0.1851,  0.8462, -0.3591]])
Input parameters:

source_vocab = 11
target_vocab = 11
N = 6
# All other parameters use their default values
Call:

if __name__ == '__main__':
    res = make_model(source_vocab, target_vocab, N)
    print(res)
Output:

# The final model structure, built according to the Transformer architecture diagram
# (the printout below shows only the first two of the N=6 stacked encoder and decoder layers)
EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (tgt_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (generator): Generator(
    (proj): Linear(in_features=512, out_features=11)
  )
)
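As a quick sanity check (an addition, not from the original tutorial), the size of the constructed model can be verified by counting its trainable parameters:

total = sum(p.numel() for p in res.parameters() if p.requires_grad)
print(f'{total:,} trainable parameters')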
We learned and implemented the class that realizes the encoder-decoder architecture: EncoderDecoder.
We learned and implemented the model-building function: make_model.