
The Transformer Architecture and Its Code Implementation


1. The Transformer Architecture Diagram

The figure below shows the overall structure of the Transformer: the left half is the Transformer Encoder block, and the right half is the Transformer Decoder block.

In the full model, both blocks are repeated several times, i.e. the output vectors of one block serve as the input vectors of the next block.

(Figure: the overall Transformer architecture)
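That stacking is nothing more than a loop in code. A rough sketch with hypothetical block objects, only to illustrate the data flow (the real loops appear later in TransformerEncoder.forward and TransformerDecoder.forward):

def run_stack(blocks, x):
    # The output of one block becomes the input of the next block.
    for block in blocks:
        x = block(x)
    return x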

2. Code Implementation

As introduced in Section 1, the Transformer consists of a TransformerEncoder and a TransformerDecoder, and each of these is in turn built from several TransformerEncoderLayers or TransformerDecoderLayers (in other words, from several stacked blocks).

The code below is easier to follow if you read it alongside the internal structure shown in the figure above.

1.1) TransformerEncoderLayer code:

# Transformer Encoder Layer
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Compare with the picture of the Transformer encoder block:
        # src = Norm(src + Dropout(self_attention(src)))
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # src = Norm(src + Dropout(Feedforward(src)))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
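A minimal usage sketch for a single encoder layer, using the torch.nn version of the class above (the shapes are arbitrary illustration values; a True entry in the padding mask marks a position that should be ignored):

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
src = torch.rand(10, 32, 512)                                  # (S, N, E)
src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)   # (N, S), True = padded position
out = encoder_layer(src, src_key_padding_mask=src_key_padding_mask)
print(out.shape)   # torch.Size([10, 32, 512]) -- the shape is preserved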

1.2) TransformerEncoder code (applies the TransformerEncoderLayer repeatedly):

# A stack of N encoder layers
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
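For reference, a short sketch of stacking the layer above into a 6-layer encoder with a final LayerNorm, mirroring what Transformer.__init__ does further below (values are illustrative):

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6, norm=nn.LayerNorm(512))
src = torch.rand(10, 32, 512)      # (S, N, E)
memory = encoder(src)              # (10, 32, 512); this is the `memory` fed to the decoder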

2.1) TransformerDecoderLayer code:

# For language reconstruction
class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    # In the author's experiment: d_model = 768, nhead = 8
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    # tgt: the sequence to the decoder layer (required), e.g. (20, 1, 768)
    # memory: the sequence from the last layer of the encoder (required), e.g. (3600, 1, 768)
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Compare with the structure of the Transformer decoder block.
        # tgt = Norm(Dropout(attention(tgt, tgt, tgt)) + tgt)
        # Multi-head self-attention; tgt2 has shape (20, 1, 768) in the author's example.
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)   # (20, 1, 768)
        tgt = self.norm1(tgt)             # LayerNorm keeps the shape: (20, 1, 768)
        # tgt = Norm(Dropout(attention(tgt, memory, memory)) + tgt)
        # Multi-head cross-attention: the query is tgt, the key and value are memory.
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)   # (20, 1, 768)
        tgt = self.norm2(tgt)             # (20, 1, 768)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))   # (20, 1, 768)
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt                        # (20, 1, 768)
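A minimal sketch of calling one decoder layer with a causal (subsequent) mask, again with illustrative shapes; the additive mask places -inf above the diagonal so position i cannot attend to later target positions:

import torch
import torch.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
memory = torch.rand(10, 32, 512)   # encoder output, (S, N, E)
tgt = torch.rand(20, 32, 512)      # decoder input,  (T, N, E)

T = tgt.size(0)
tgt_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)   # (T, T) additive causal mask
out = decoder_layer(tgt, memory, tgt_mask=tgt_mask)                    # (20, 32, 512)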

2.2) TransformerDecoder code (applies the TransformerDecoderLayer repeatedly):

# A stack of N decoder layers
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
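And the corresponding sketch for the full decoder stack; note that every layer reuses the same memory tensor (illustrative shapes):

import torch
import torch.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6, norm=nn.LayerNorm(512))
memory = torch.rand(10, 32, 512)   # (S, N, E)
tgt = torch.rand(20, 32, 512)      # (T, N, E)
out = decoder(tgt, memory)         # (20, 32, 512)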

2.3) Transformer code:

a. First initialize the TransformerEncoder and the TransformerDecoder.

b. Then call each of them in turn inside forward().

# High-level architecture: Transformer encoder + Transformer decoder
class Transformer(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
    model with corresponding parameters.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None) -> None:
        super(Transformer, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            encoder_norm = LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            decoder_norm = LayerNorm(d_model)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
            positions will be unchanged. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

            - output: :math:`(T, N, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
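Putting the two halves together, a minimal end-to-end sketch (illustrative shapes; generate_square_subsequent_mask returns a float mask with 0 on and below the diagonal and -inf above it):

import torch
import torch.nn as nn

model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
src = torch.rand(10, 32, 512)      # (S, N, E)
tgt = torch.rand(20, 32, 512)      # (T, N, E)
tgt_mask = model.generate_square_subsequent_mask(20)   # (T, T) causal mask
out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)   # torch.Size([20, 32, 512])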

Below is the complete transformer.py file:

import copy
from typing import Optional, Any

import torch
from torch import Tensor
from .. import functional as F
from .module import Module
from .activation import MultiheadAttention
from .container import ModuleList
from ..init import xavier_uniform_
from .dropout import Dropout
from .linear import Linear
from .normalization import LayerNorm


# High-level architecture: Transformer encoder + Transformer decoder
class Transformer(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
    model with corresponding parameters.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None) -> None:
        super(Transformer, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            encoder_norm = LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            decoder_norm = LayerNorm(d_model)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
            positions will be unchanged. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

            - output: :math:`(T, N, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


# A stack of N encoder layers
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


# A stack of N decoder layers
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


# Transformer Encoder Layer
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Compare with the picture of the Transformer encoder block:
        # src = Norm(src + Dropout(self_attention(src)))
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # src = Norm(src + Dropout(Feedforward(src)))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


# For language reconstruction
class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    # In the author's experiment: d_model = 768, nhead = 8
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    # tgt: the sequence to the decoder layer (required), e.g. (20, 1, 768)
    # memory: the sequence from the last layer of the encoder (required), e.g. (3600, 1, 768)
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Compare with the structure of the Transformer decoder block.
        # tgt = Norm(Dropout(attention(tgt, tgt, tgt)) + tgt)
        # Multi-head self-attention; tgt2 has shape (20, 1, 768) in the author's example.
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)   # (20, 1, 768)
        tgt = self.norm1(tgt)             # LayerNorm keeps the shape: (20, 1, 768)
        # tgt = Norm(Dropout(attention(tgt, memory, memory)) + tgt)
        # Multi-head cross-attention: the query is tgt, the key and value are memory.
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)   # (20, 1, 768)
        tgt = self.norm2(tgt)             # (20, 1, 768)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))   # (20, 1, 768)
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt                        # (20, 1, 768)


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
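Note that this file uses torch-internal relative imports (from .. import functional as F, and so on), so it only runs in its original location inside torch/nn/modules. If you want to paste the classes into a standalone script, a sketch of the equivalent public imports (my own mapping, not part of the original file):

import copy
from typing import Optional, Any
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList, MultiheadAttention, Dropout, Linear, LayerNorm
from torch.nn.init import xavier_uniform_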

3. Multi-Head Attention in the Transformer

Multi-head attention is used several times in the Transformer.

In each EncoderLayer, multi-head self-attention is used once.

In each DecoderLayer, multi-head self-attention is used first, immediately followed by multi-head cross-attention (the query is tgt, while the key and value are memory, the output of the final encoder block).

The implementation of multi-head attention is as follows:

# Multi-head attention
# (excerpted from torch/nn/modules/activation.py; besides F, it also relies on Parameter,
#  the init functions xavier_uniform_/xavier_normal_/constant_, and the internal _LinearWithBias layer)
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
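The forward pass above delegates all of the actual computation to F.multi_head_attention_forward. Conceptually, each head computes scaled dot-product attention; a simplified sketch of that core operation (not the real torch internals, which also handle the input/output projections, dropout, and the various mask types):

import math
import torch

def scaled_dot_product_attention(q, k, v, attn_mask=None):
    # q: (B, L, d_k), k and v: (B, S, d_k), where B folds batch and heads together.
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))   # (B, L, S)
    if attn_mask is not None:
        scores = scores + attn_mask            # additive mask: -inf disables a position
    weights = torch.softmax(scores, dim=-1)    # attention weights over source positions
    return torch.bmm(weights, v)               # (B, L, d_k)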

This post is a personal study note. If you spot any mistakes, corrections are very welcome. Thanks!

        
