
Transformer Model Construction

2.6 Model Construction


Learning objectives

  • Master the implementation of the encoder-decoder structure.
  • Master the process of building the Transformer model.

  • In the previous sections we implemented all of the component parts; next we assemble them into the complete encoder-decoder structure.

  • Overall Transformer architecture diagram: [figure]


Code implementation of the encoder-decoder structure

    # Implement the encoder-decoder architecture with the EncoderDecoder class
    class EncoderDecoder(nn.Module):
        def __init__(self, encoder, decoder, source_embed, target_embed, generator):
            """The initializer takes 5 arguments: the encoder object, the decoder object,
            the source embedding function, the target embedding function,
            and the generator object of the output part."""
            super(EncoderDecoder, self).__init__()
            # Store the arguments on the instance
            self.encoder = encoder
            self.decoder = decoder
            self.src_embed = source_embed
            self.tgt_embed = target_embed
            self.generator = generator

        def forward(self, source, target, source_mask, target_mask):
            """forward takes four arguments: source is the source data, target is the
            target data, and source_mask and target_mask are the corresponding mask tensors."""
            # Pass source and source_mask to the encode function, then pass its result,
            # together with source_mask, target and target_mask, to the decode function.
            return self.decode(self.encode(source, source_mask), source_mask,
                               target, target_mask)

        def encode(self, source, source_mask):
            """Encoding function; takes source and source_mask as arguments."""
            # Process source with src_embed, then pass it together with source_mask to self.encoder
            return self.encoder(self.src_embed(source), source_mask)

        def decode(self, memory, source_mask, target, target_mask):
            """Decoding function; takes memory (the encoder output), source_mask,
            target and target_mask as arguments."""
            # Process target with tgt_embed, then pass it together with memory,
            # source_mask and target_mask to self.decoder
            return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

  • Instantiation parameters:

    vocab_size = 1000
    d_model = 512
    # en, de and gen are the encoder, decoder and generator objects built in the previous sections
    encoder = en
    decoder = de
    source_embed = nn.Embedding(vocab_size, d_model)
    target_embed = nn.Embedding(vocab_size, d_model)
    generator = gen
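
  • For context, en, de and gen refer to objects created in the earlier sections. A minimal, hypothetical sketch of how they could be instantiated from the classes built there (the hyperparameter values shown are illustrative only and may differ from the ones used earlier):

    import copy
    c = copy.deepcopy

    head, d_ff, dropout, N = 8, 64, 0.2, 8   # illustrative values only
    attn = MultiHeadedAttention(head, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    en = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
    de = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N)
    gen = Generator(d_model, vocab_size)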

  • Input:

    # Assume the source data and the target data are the same; in practice they differ
    source = target = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))
    # Assume src_mask and tgt_mask are the same; in practice they differ
    source_mask = target_mask = Variable(torch.zeros(8, 4, 4))
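
  • In a real setting the target mask would also block attention to future positions. A minimal, hypothetical sketch of such a look-ahead mask, following the subsequent-mask idea from the earlier mask-tensor section (the helper name and shape conventions here are assumptions):

    import numpy as np
    import torch

    def subsequent_mask(size):
        """Lower-triangular mask: position i may attend only to positions <= i."""
        attn_shape = (1, size, size)
        mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
        return torch.from_numpy(1 - mask)

    # A 4x4 look-ahead mask matching the 4-token target sequences above
    print(subsequent_mask(4))
    # tensor([[[1, 0, 0, 0],
    #          [1, 1, 0, 0],
    #          [1, 1, 1, 0],
    #          [1, 1, 1, 1]]], dtype=torch.uint8)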

  • Invocation:

    ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
    ed_result = ed(source, target, source_mask, target_mask)
    print(ed_result)
    print(ed_result.shape)

  • Output:

    tensor([[[ 0.2102, -0.0826, -0.0550,  ...,  1.5555,  1.3025, -0.6296],
             [ 0.8270, -0.5372, -0.9559,  ...,  0.3665,  0.4338, -0.7505],
             [ 0.4956, -0.5133, -0.9323,  ...,  1.0773,  1.1913, -0.6240],
             [ 0.5770, -0.6258, -0.4833,  ...,  0.1171,  1.0069, -1.9030]],

            [[-0.4355, -1.7115, -1.5685,  ..., -0.6941, -0.1878, -0.1137],
             [-0.8867, -1.2207, -1.4151,  ..., -0.9618,  0.1722, -0.9562],
             [-0.0946, -0.9012, -1.6388,  ..., -0.2604, -0.3357, -0.6436],
             [-1.1204, -1.4481, -1.5888,  ..., -0.8816, -0.6497,  0.0606]]],
           grad_fn=<AddBackward0>)
    torch.Size([2, 4, 512])
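
  • Note that forward returns the raw decoder output of shape (batch, seq_len, d_model); the generator stored on the model is not applied inside forward. A hedged sketch of applying it separately (assuming gen is the Generator object from the earlier output-part section, built with some vocabulary size V):

    # The generator projects d_model -> V and applies log_softmax,
    # so the result would have shape (2, 4, V).
    log_probs = ed.generator(ed_result)
    print(log_probs.shape)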

  • Next, we build the model used for training on top of this structure.

Code analysis of the Transformer model construction process

    def make_model(source_vocab, target_vocab, N=6,
                   d_model=512, d_ff=2048, head=8, dropout=0.1):
        """Builds the model. The 7 parameters are: the number of source features (vocabulary size),
        the number of target features (vocabulary size), the number of stacked encoder and
        decoder layers, the embedding dimension, the dimension of the transformation matrices
        in the position-wise feed-forward network, the number of attention heads,
        and the dropout (zero-out) rate."""
        # Get a deep-copy handle; many of the structures below must be deep-copied
        # so that they stay independent of one another.
        c = copy.deepcopy

        # Instantiate the multi-headed attention class, obtaining the object attn
        attn = MultiHeadedAttention(head, d_model)

        # Instantiate the position-wise feed-forward class, obtaining the object ff
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)

        # Instantiate the positional-encoding class, obtaining the object position
        position = PositionalEncoding(d_model, dropout)

        # Following the architecture diagram, the outermost structure is EncoderDecoder.
        # It contains the encoder, the decoder, a Sequential of the source embedding layer
        # and the positional encoding, a Sequential of the target embedding layer and the
        # positional encoding, and the generator of the output part.
        # Each encoder layer holds an attention sub-layer and a feed-forward sub-layer;
        # each decoder layer holds two attention sub-layers and a feed-forward sub-layer.
        model = EncoderDecoder(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            Decoder(DecoderLayer(d_model, c(attn), c(attn),
                                 c(ff), dropout), N),
            nn.Sequential(Embeddings(d_model, source_vocab), c(position)),
            nn.Sequential(Embeddings(d_model, target_vocab), c(position)),
            Generator(d_model, target_vocab))

        # Once the model structure is complete, initialize its parameters,
        # e.g. the transformation matrices of the linear layers.
        # Any parameter with more than one dimension is initialized from a
        # (Xavier) uniform distribution.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        return model
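
  • As a quick (hypothetical) sanity check on the deep copies: the attention modules in different layers of the built model should be distinct objects with separate parameter storage, for example:

    m = make_model(11, 11, N=2)
    a0 = m.encoder.layers[0].self_attn
    a1 = m.encoder.layers[1].self_attn
    print(a0 is a1)  # False: each layer keeps its own deep-copied attention module
    print(a0.linears[0].weight.data_ptr() == a1.linears[0].weight.data_ptr())  # False: separate tensors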

  • nn.init.xavier_uniform_ demo:

    # The resulting values follow a uniform distribution U(-a, a)
    >>> w = torch.empty(3, 5)
    >>> w = nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    >>> w
    tensor([[-0.7742,  0.5413,  0.5478, -0.4806, -0.2555],
            [-0.8358,  0.4673,  0.3012,  0.3882, -0.6375],
            [ 0.4622, -0.0794,  0.1851,  0.8462, -0.3591]])
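
  • For reference, the bound a of that uniform distribution is a = gain * sqrt(6 / (fan_in + fan_out)). With gain = calculate_gain('relu') = sqrt(2) and a 3x5 tensor (fan_in = 5, fan_out = 3), a ≈ 1.22, which matches the magnitude of the sampled values above.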

  • Input:

    source_vocab = 11
    target_vocab = 11
    N = 6
    # All other parameters use their default values

  • Invocation:

    if __name__ == '__main__':
        res = make_model(source_vocab, target_vocab, N)
        print(res)

  • Output:

    # The final model structure, built according to the Transformer architecture diagram
    EncoderDecoder(
      (encoder): Encoder(
        (layers): ModuleList(
          (0): EncoderLayer(
            (self_attn): MultiHeadedAttention(
              (linears): ModuleList(
                (0): Linear(in_features=512, out_features=512)
                (1): Linear(in_features=512, out_features=512)
                (2): Linear(in_features=512, out_features=512)
                (3): Linear(in_features=512, out_features=512)
              )
              (dropout): Dropout(p=0.1)
            )
            (feed_forward): PositionwiseFeedForward(
              (w_1): Linear(in_features=512, out_features=2048)
              (w_2): Linear(in_features=2048, out_features=512)
              (dropout): Dropout(p=0.1)
            )
            (sublayer): ModuleList(
              (0): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
              (1): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
            )
          )
          (1): EncoderLayer(
            (self_attn): MultiHeadedAttention(
              (linears): ModuleList(
                (0): Linear(in_features=512, out_features=512)
                (1): Linear(in_features=512, out_features=512)
                (2): Linear(in_features=512, out_features=512)
                (3): Linear(in_features=512, out_features=512)
              )
              (dropout): Dropout(p=0.1)
            )
            (feed_forward): PositionwiseFeedForward(
              (w_1): Linear(in_features=512, out_features=2048)
              (w_2): Linear(in_features=2048, out_features=512)
              (dropout): Dropout(p=0.1)
            )
            (sublayer): ModuleList(
              (0): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
              (1): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
            )
          )
        )
        (norm): LayerNorm(
        )
      )
      (decoder): Decoder(
        (layers): ModuleList(
          (0): DecoderLayer(
            (self_attn): MultiHeadedAttention(
              (linears): ModuleList(
                (0): Linear(in_features=512, out_features=512)
                (1): Linear(in_features=512, out_features=512)
                (2): Linear(in_features=512, out_features=512)
                (3): Linear(in_features=512, out_features=512)
              )
              (dropout): Dropout(p=0.1)
            )
            (src_attn): MultiHeadedAttention(
              (linears): ModuleList(
                (0): Linear(in_features=512, out_features=512)
                (1): Linear(in_features=512, out_features=512)
                (2): Linear(in_features=512, out_features=512)
                (3): Linear(in_features=512, out_features=512)
              )
              (dropout): Dropout(p=0.1)
            )
            (feed_forward): PositionwiseFeedForward(
              (w_1): Linear(in_features=512, out_features=2048)
              (w_2): Linear(in_features=2048, out_features=512)
              (dropout): Dropout(p=0.1)
            )
            (sublayer): ModuleList(
              (0): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
              (1): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
              (2): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
            )
          )
          (1): DecoderLayer(
            (self_attn): MultiHeadedAttention(
              (linears): ModuleList(
                (0): Linear(in_features=512, out_features=512)
                (1): Linear(in_features=512, out_features=512)
                (2): Linear(in_features=512, out_features=512)
                (3): Linear(in_features=512, out_features=512)
              )
              (dropout): Dropout(p=0.1)
            )
            (src_attn): MultiHeadedAttention(
              (linears): ModuleList(
                (0): Linear(in_features=512, out_features=512)
                (1): Linear(in_features=512, out_features=512)
                (2): Linear(in_features=512, out_features=512)
                (3): Linear(in_features=512, out_features=512)
              )
              (dropout): Dropout(p=0.1)
            )
            (feed_forward): PositionwiseFeedForward(
              (w_1): Linear(in_features=512, out_features=2048)
              (w_2): Linear(in_features=2048, out_features=512)
              (dropout): Dropout(p=0.1)
            )
            (sublayer): ModuleList(
              (0): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
              (1): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
              (2): SublayerConnection(
                (norm): LayerNorm(
                )
                (dropout): Dropout(p=0.1)
              )
            )
          )
        )
        (norm): LayerNorm(
        )
      )
      (src_embed): Sequential(
        (0): Embeddings(
          (lut): Embedding(11, 512)
        )
        (1): PositionalEncoding(
          (dropout): Dropout(p=0.1)
        )
      )
      (tgt_embed): Sequential(
        (0): Embeddings(
          (lut): Embedding(11, 512)
        )
        (1): PositionalEncoding(
          (dropout): Dropout(p=0.1)
        )
      )
      (generator): Generator(
        (proj): Linear(in_features=512, out_features=11)
      )
    )
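
  • As a quick smoke test (not part of the original output), the built model can be run in the same way as the EncoderDecoder demo above; a hedged sketch, reusing the same dummy mask shape and keeping token ids below the vocabulary size of 11:

    src = tgt = Variable(torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]]))
    src_mask = tgt_mask = Variable(torch.zeros(8, 4, 4))
    out = res(src, tgt, src_mask, tgt_mask)
    print(out.shape)   # expected: torch.Size([2, 4, 512]), i.e. (batch, seq_len, d_model)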

Section summary

  • Learned and implemented the encoder-decoder class: EncoderDecoder

    • The class initializer takes 5 arguments: the encoder object, the decoder object, the source embedding function, the target embedding function, and the generator object of the output part.
    • The class implements three methods: forward, encode and decode.
    • forward is the main logic function; it takes four arguments: source is the source data, target is the target data, and source_mask and target_mask are the corresponding mask tensors.
    • encode is the encoding function; it takes source and source_mask as arguments.
    • decode is the decoding function; it takes memory (the encoder output), source_mask, target and target_mask as arguments.

  • Learned and implemented the model-building function: make_model

    • It takes 7 parameters: the number of source features (vocabulary size), the number of target features (vocabulary size), the number of stacked encoder and decoder layers, the embedding dimension, the dimension of the transformation matrices in the feed-forward network, the number of attention heads, and the dropout (zero-out) rate.
    • The function returns a fully constructed model object.