Large Model Fundamentals: Implementing a Transformer from Scratch (1) - CSDN Blog
Large Model Fundamentals: Implementing a Transformer from Scratch (2) - CSDN Blog
Large Model Fundamentals: Implementing a Transformer from Scratch (3) - CSDN Blog
Large Model Fundamentals: Implementing a Transformer from Scratch (4) - CSDN Blog
The previous articles already implemented both the Encoder and the Decoder modules;
next we assemble the complete Transformer.
The overall Transformer architecture is shown above. We simply import the Encoder and Decoder modules we built and stack them together:
import torch
from torch import nn, Tensor
from torch.nn import Embedding

# Import the modules implemented in the previous articles
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_mask


class Transformer(nn.Module):
    def __init__(self,
                 source_vocab_size: int,
                 target_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 max_positions: int = 5000,
                 pad_idx: int = 0,
                 norm_first: bool = False) -> None:
        '''
        :param source_vocab_size: size of the source vocabulary.
        :param target_vocab_size: size of the target vocabulary.
        :param d_model: dimension of embeddings. Defaults to 512.
        :param n_heads: number of heads. Defaults to 8.
        :param num_encoder_layers: number of encoder blocks. Defaults to 6.
        :param num_decoder_layers: number of decoder blocks. Defaults to 6.
        :param d_ff: dimension of inner feed-forward network. Defaults to 2048.
        :param dropout: dropout ratio. Defaults to 0.1.
        :param max_positions: maximum sequence length for positional encoding. Defaults to 5000.
        :param pad_idx: pad index. Defaults to 0.
        :param norm_first: if True, layer norm is applied before attention and feed-forward operations (Pre-Norm).
            Otherwise it is applied after (Post-Norm). Defaults to False.
        '''
        super().__init__()
        # Token embeddings
        self.src_embeddings = Embedding(source_vocab_size, d_model)
        self.target_embeddings = Embedding(target_vocab_size, d_model)
        # Positional embeddings
        self.encoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        self.decoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        # Encoder stack
        self.encoder = Encoder(d_model, num_encoder_layers, n_heads, d_ff, dropout, norm_first)
        # Decoder stack
        self.decoder = Decoder(d_model, num_decoder_layers, n_heads, d_ff, dropout, norm_first)
        self.pad_idx = pad_idx

    def encode(self,
               src: Tensor,
               src_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Encoding pass.
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence
        :param keep_attentions: whether to keep attention weights or not. Defaults to False.
        :return: (batch_size, src_seq_length, d_model) encoder output
        '''
        src_embedding_tensor = self.src_embeddings(src)
        src_embedded = self.encoder_pos(src_embedding_tensor)
        return self.encoder(src_embedded, src_mask, keep_attentions)

    def decode(self,
               target_tensor: Tensor,
               memory: Tensor,
               target_mask: Tensor = None,
               memory_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Decoding pass.
        :param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param memory: (batch_size, src_seq_length, d_model) the sequence from the last layer of the encoder.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights or not. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output
        '''
        target_embedding_tensor = self.target_embeddings(target_tensor)
        target_embedded = self.decoder_pos(target_embedding_tensor)
        # output (batch_size, tgt_seq_length, d_model)
        output = self.decoder(target_embedded, memory, target_mask, memory_mask, keep_attentions)
        return output

    def forward(self,
                src: Tensor,
                target: Tensor,
                src_mask: Tensor = None,
                target_mask: Tensor = None,
                keep_attentions: bool = False) -> Tensor:
        '''
        Full encoder-decoder pass.
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param target: (batch_size, tgt_seq_length) the sequence to the decoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the source sequence. Defaults to None.
        :param target_mask: the mask for the target sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights or not. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model)
        '''
        memory = self.encode(src, src_mask, keep_attentions)
        return self.decode(target, memory, target_mask, src_mask, keep_attentions)
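Note that forward returns the decoder's hidden states of dimension d_model; the model above does not include the final linear projection onto the target vocabulary that the original Transformer uses to produce token logits. If you want real vocabulary logits, you can wrap the model with one extra linear layer. Below is a minimal sketch; the TransformerLM name and the lm_head attribute are illustrative, not part of the series' code:

from torch import nn, Tensor

class TransformerLM(nn.Module):
    '''Hypothetical wrapper: Transformer body plus a projection to vocabulary logits.'''
    def __init__(self, transformer: Transformer, target_vocab_size: int, d_model: int = 512) -> None:
        super().__init__()
        self.transformer = transformer
        # maps (batch_size, tgt_seq_length, d_model) -> (batch_size, tgt_seq_length, target_vocab_size)
        self.lm_head = nn.Linear(d_model, target_vocab_size)

    def forward(self, src: Tensor, target: Tensor,
                src_mask: Tensor = None, target_mask: Tensor = None) -> Tensor:
        hidden = self.transformer(src, target, src_mask, target_mask)
        return self.lm_head(hidden)

At training time such logits pair naturally with nn.CrossEntropyLoss(ignore_index=pad_idx); some implementations also tie lm_head.weight to the target embedding matrix to save parameters.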
Let's write a simple main function to check that the whole network runs correctly:
if __name__ == '__main__':
    source_vocab_size = 300
    target_vocab_size = 300
    # index used for padding, usually 0
    pad_idx = 0

    batch_size = 1
    seq_length = 20

    model = Transformer(source_vocab_size=source_vocab_size,
                        target_vocab_size=target_vocab_size)

    src_tensor = torch.randint(0, source_vocab_size, (batch_size, seq_length))
    target_tensor = torch.randint(0, target_vocab_size, (batch_size, seq_length))

    # the last 5 positions of the source are padding
    src_tensor[:, -5:] = 0

    # the last 10 positions of the target are padding
    target_tensor[:, -10:] = 0

    # (batch_size, 1, src_seq_length) padding mask for the source
    src_mask = (src_tensor != pad_idx).unsqueeze(1)
    # combined padding + causal mask for the target
    target_mask = make_target_mask(target_tensor)

    output = model(src_tensor, target_tensor, src_mask, target_mask)
    print(output.shape)
    # torch.Size([1, 20, 512])
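The make_target_mask helper imported from llm_base.mask.target_mask was implemented in an earlier part of the series and is not repeated here. For reference, a typical target mask combines the padding mask with a lower-triangular (causal) mask, so that position i can only attend to positions at or before i. Here is a minimal sketch under that assumption; make_target_mask_sketch is a stand-in name, and the shapes returned by the series' own helper may differ:

import torch
from torch import Tensor

def make_target_mask_sketch(target: Tensor, pad_idx: int = 0) -> Tensor:
    '''Hypothetical re-implementation: combine padding mask and causal mask.
    Returns a boolean mask of shape (batch_size, 1, tgt_seq_length, tgt_seq_length).'''
    tgt_seq_length = target.size(1)
    # (batch_size, 1, 1, tgt_seq_length): True where the token is not padding
    padding_mask = (target != pad_idx).unsqueeze(1).unsqueeze(2)
    # (tgt_seq_length, tgt_seq_length): True on and below the diagonal
    causal_mask = torch.tril(torch.ones(tgt_seq_length, tgt_seq_length, device=target.device)).bool()
    # broadcast to (batch_size, 1, tgt_seq_length, tgt_seq_length)
    return padding_mask & causal_mask

With the toy batch above (batch_size 1, 20 target tokens, the last 10 of which are padding), this sketch would produce a (1, 1, 20, 20) boolean mask.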