
LLM Basics: Implementing a Transformer from Scratch (5)

LLM Basics: Implementing a Transformer from Scratch (1) - CSDN Blog

LLM Basics: Implementing a Transformer from Scratch (2) - CSDN Blog

LLM Basics: Implementing a Transformer from Scratch (3) - CSDN Blog

LLM Basics: Implementing a Transformer from Scratch (4) - CSDN Blog


1. Introduction

The previous article finished implementing both the Encoder and the Decoder module.

In this article we put them together into the complete Transformer.

2. Transformer

Following the overall Transformer architecture, we simply import the Encoder and Decoder modules implemented earlier and stack them together:

import torch
from torch import nn, Tensor
from torch.nn import Embedding

# Import the modules implemented in the previous articles
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_mask


class Transformer(nn.Module):
    def __init__(self,
                 source_vocab_size: int,
                 target_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 max_positions: int = 5000,
                 pad_idx: int = 0,
                 norm_first: bool = False) -> None:
        '''
        :param source_vocab_size: size of the source vocabulary.
        :param target_vocab_size: size of the target vocabulary.
        :param d_model: dimension of embeddings. Defaults to 512.
        :param n_heads: number of heads. Defaults to 8.
        :param num_encoder_layers: number of encoder blocks. Defaults to 6.
        :param num_decoder_layers: number of decoder blocks. Defaults to 6.
        :param d_ff: dimension of inner feed-forward network. Defaults to 2048.
        :param dropout: dropout ratio. Defaults to 0.1.
        :param max_positions: maximum sequence length for positional encoding. Defaults to 5000.
        :param pad_idx: pad index. Defaults to 0.
        :param norm_first: if True, layer norm is done prior to attention and feed-forward
            operations (Pre-Norm). Otherwise it is done after (Post-Norm). Defaults to False.
        '''
        super().__init__()
        # Token embeddings
        self.src_embeddings = Embedding(source_vocab_size, d_model)
        self.target_embeddings = Embedding(target_vocab_size, d_model)
        # Positional embeddings
        self.encoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        self.decoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        # Encoder stack
        self.encoder = Encoder(d_model, num_encoder_layers, n_heads, d_ff, dropout, norm_first)
        # Decoder stack
        self.decoder = Decoder(d_model, num_decoder_layers, n_heads, d_ff, dropout, norm_first)
        self.pad_idx = pad_idx

    def encode(self,
               src: Tensor,
               src_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Encoding pass.

        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, src_seq_length, d_model) encoder output
        '''
        src_embedding_tensor = self.src_embeddings(src)
        src_embedded = self.encoder_pos(src_embedding_tensor)
        return self.encoder(src_embedded, src_mask, keep_attentions)

    def decode(self,
               target_tensor: Tensor,
               memory: Tensor,
               target_mask: Tensor = None,
               memory_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Decoding pass.

        :param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param memory: (batch_size, src_seq_length, d_model) the sequence from the last layer of the encoder.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output
        '''
        target_embedding_tensor = self.target_embeddings(target_tensor)
        target_embedded = self.decoder_pos(target_embedding_tensor)
        # output (batch_size, tgt_seq_length, d_model)
        logits = self.decoder(target_embedded, memory, target_mask, memory_mask, keep_attentions)
        return logits

    def forward(self,
                src: Tensor,
                target: Tensor,
                src_mask: Tensor = None,
                target_mask: Tensor = None,
                keep_attention: bool = False) -> Tensor:
        '''
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param target: (batch_size, tgt_seq_length) the sequence to the decoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the source sequence. Defaults to None.
        :param target_mask: the mask for the target sequence (see make_target_mask). Defaults to None.
        :param keep_attention: whether to keep attention weights. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model)
        '''
        memory = self.encode(src, src_mask, keep_attention)
        return self.decode(target, memory, target_mask, src_mask, keep_attention)
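
One thing worth pointing out: as the test in the next section confirms, forward returns the decoder's hidden states of shape (batch_size, tgt_seq_length, d_model), not vocabulary logits. For an actual translation task you would normally add a final linear projection onto the target vocabulary. Below is a minimal sketch, assuming the Transformer class above is kept unchanged; the wrapper class and the lm_head attribute are illustrative names, not part of the series' code.

from torch import nn, Tensor

class TransformerWithLMHead(nn.Module):
    '''Illustrative wrapper only: adds a vocabulary projection on top of the Transformer above.'''
    def __init__(self, transformer: nn.Module, d_model: int, target_vocab_size: int):
        super().__init__()
        self.transformer = transformer
        # (batch_size, tgt_seq_length, d_model) -> (batch_size, tgt_seq_length, target_vocab_size)
        self.lm_head = nn.Linear(d_model, target_vocab_size)

    def forward(self, src: Tensor, target: Tensor,
                src_mask: Tensor = None, target_mask: Tensor = None) -> Tensor:
        hidden = self.transformer(src, target, src_mask, target_mask)
        return self.lm_head(hidden)  # vocabulary logits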

3. Testing

Let's write a simple main function to check that the whole network runs correctly.

if __name__ == '__main__':
    source_vocab_size = 300
    target_vocab_size = 300
    # Index used for padding, usually 0
    pad_idx = 0
    batch_size = 1
    max_positions = 20

    model = Transformer(source_vocab_size=source_vocab_size,
                        target_vocab_size=target_vocab_size)

    src_tensor = torch.randint(0, source_vocab_size, (batch_size, max_positions))
    target_tensor = torch.randint(0, target_vocab_size, (batch_size, max_positions))
    # The last 5 positions of the source are padding
    src_tensor[:, -5:] = 0
    # The last 10 positions of the target are padding
    target_tensor[:, -10:] = 0

    src_mask = (src_tensor != pad_idx).unsqueeze(1)
    target_mask = make_target_mask(target_tensor)

    logits = model(src_tensor, target_tensor, src_mask, target_mask)
    print(logits.shape)
    # torch.Size([1, 20, 512])
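
The make_target_mask helper was implemented in an earlier part of this series and is only imported here. For readers jumping in at this article, a typical implementation combines the padding mask with a causal (lower-triangular) mask, so that each target position can only attend to non-padding positions at or before itself. A minimal sketch is shown below, assuming a boolean mask of shape (batch_size, 1, tgt_len, tgt_len); the exact shape and convention may differ from the version used in this series.

import torch
from torch import Tensor

def make_target_mask_sketch(target: Tensor, pad_idx: int = 0) -> Tensor:
    '''Illustrative only; not the series' actual make_target_mask.'''
    tgt_len = target.size(1)
    # (batch_size, 1, 1, tgt_len): True where the token is not padding
    padding_mask = (target != pad_idx).unsqueeze(1).unsqueeze(2)
    # (1, 1, tgt_len, tgt_len): True on and below the diagonal (causal constraint)
    causal_mask = torch.tril(
        torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=target.device)
    ).unsqueeze(0).unsqueeze(0)
    # Broadcasting combines the two constraints into (batch_size, 1, tgt_len, tgt_len)
    return padding_mask & causal_mask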