
Implementing a Transformer in PyTorch (with comments)


Abstract

This post walks through a from-scratch PyTorch implementation of the Transformer encoder-decoder: the positional encoding, multi-head attention, the feed-forward and Add & Norm sub-layers, the encoder and decoder blocks, and the full model, all with comments.

1. Importing libraries

import copy
import torch
import torch.nn.functional as F
from torch import nn
import math

2. Model architecture

[Figure: the Transformer encoder-decoder architecture]

2.1 PositionalEncoding

class PositionalEncoding(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, X):
        # X has shape: (batch_size, num_steps, embedding_size)
        batch_size = X.shape[0]
        num_steps = X.shape[1]
        embedding_size = X.shape[2]
        position = torch.zeros(num_steps, embedding_size, device=self.device)
        # value[p, i] = p / 10000^(i / embedding_size)
        value = torch.arange(num_steps, device=self.device).repeat(embedding_size, 1).permute(1, 0) \
            / torch.pow(10000, torch.arange(embedding_size, device=self.device) / embedding_size).repeat(num_steps, 1)
        # even dimensions use sin, odd dimensions use cos
        position[:, 0::2] = torch.sin(value[:, 0::2])
        position[:, 1::2] = torch.cos(value[:, 1::2])
        # return the sin/cos encodings (not the raw angles), broadcast over the batch
        return position.repeat(batch_size, 1, 1)
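
As a quick sanity check, the module can be exercised on a dummy batch; this is a minimal sketch assuming the imports and the class above, run on the CPU:

device = torch.device('cpu')
pos_enc = PositionalEncoding(device)
X = torch.zeros(2, 10, 32)              # batch_size=2, num_steps=10, embedding_size=32
P = pos_enc(X)
print(P.shape)                          # torch.Size([2, 10, 32])
print(torch.allclose(P[0], P[1]))       # True: the encoding is identical for every sample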

2.2 Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, device):
        super().__init__()
        self.device = device
        self.num_heads = num_heads
        self.W_Q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_K = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_V = nn.Linear(value_size, num_hiddens, bias=False)
        self.W_O = nn.Linear(num_hiddens, num_hiddens, bias=False)

    def reform(self, X):
        # (batch_size, num_steps, num_hiddens) -> (batch_size*num_heads, num_steps, num_hiddens/num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        X = X.reshape(-1, X.shape[2], X.shape[3])
        return X

    def reform_back(self, X):
        # (batch_size*num_heads, num_steps, num_hiddens/num_heads) -> (batch_size, num_steps, num_hiddens)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        X = X.reshape(X.shape[0], X.shape[1], -1)
        return X

    def attention(self, queries, keys, values, valid_len):
        keys_num_steps = keys.shape[1]
        queries_num_steps = queries.shape[1]
        # valid_len above has shape (batch_size,)
        d = queries.shape[-1]
        # A has shape (batch_size*num_heads, queries_num_steps, keys_num_steps)
        A = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        if valid_len is not None:
            # valid_len is given in two cases:
            # case 1: encoder queries attending to encoder keys/values, where valid_len masks the encoder padding
            # case 2: decoder queries attending to encoder keys/values, where valid_len again masks the encoder padding
            valid_len = torch.repeat_interleave(valid_len, repeats=self.num_heads, dim=0)
            mask = torch.arange(1, keys_num_steps + 1, device=self.device)[None, None, :] > valid_len[:, None, None]
            mask = mask.repeat(1, queries_num_steps, 1)
            A[mask] = -1e6
        else:
            # valid_len is None only for decoder self-attention: mask everything after the current position.
            # query i sits at absolute position keys_num_steps - queries_num_steps + i, so it may not look past it
            # (during training queries_num_steps == keys_num_steps and this is the usual upper-triangular mask;
            #  during step-by-step prediction queries_num_steps == 1 and nothing is masked)
            mask = torch.arange(keys_num_steps, device=self.device)[None, :] > \
                (keys_num_steps - queries_num_steps + torch.arange(queries_num_steps, device=self.device))[:, None]
            A[:, mask] = -1e6
        A_softmaxed = F.softmax(A, dim=-1)
        attention = torch.bmm(A_softmaxed, values)
        return attention

    def forward(self, queries, keys, values, valid_len):
        # queries, keys, values have shape: (batch_size, num_steps, embedding_size)
        # Q, K, V have shape: (batch_size, num_steps, num_hiddens)
        Q = self.W_Q(queries)
        K = self.W_K(keys)
        V = self.W_V(values)
        # reshape Q, K, V to (batch_size*num_heads, num_steps, num_hiddens/num_heads)
        Q = self.reform(Q)
        K = self.reform(K)
        V = self.reform(V)
        # compute attention; it has shape (batch_size*num_heads, num_steps, num_hiddens/num_heads)
        attention = self.attention(Q, K, V, valid_len)
        # reshape attention back to (batch_size, num_steps, num_hiddens)
        attention = self.reform_back(attention)
        return self.W_O(attention)
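
A minimal shape check, assuming the class above: with embedding_size = num_hiddens = 32 and 4 heads, self-attention over a padded batch preserves the (batch_size, num_steps, num_hiddens) shape:

device = torch.device('cpu')
attn = MultiHeadAttention(32, 32, 32, 32, num_heads=4, device=device)
X = torch.randn(2, 10, 32)
valid_len = torch.tensor([10, 6])   # the second sequence has 4 padded steps
out = attn(X, X, X, valid_len)
print(out.shape)                    # torch.Size([2, 10, 32])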

2.3 Feed Forward

class FeedForward(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        # position-wise feed-forward network with an inner dimension of 2048
        self.linear1 = nn.Linear(embedding_size, 2048)
        self.linear2 = nn.Linear(2048, embedding_size)

    def forward(self, X):
        return self.linear2(F.relu(self.linear1(X)))
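
Because both linear layers act only on the last dimension, the network is applied position-wise and the (batch_size, num_steps, embedding_size) shape is preserved; a minimal sketch assuming the class above:

ffn = FeedForward(32)
X = torch.randn(2, 10, 32)
print(ffn(X).shape)   # torch.Size([2, 10, 32])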

2.4 SubLayer

class SubLayer(nn.Module):
    def __init__(self, layer, embedding_size):
        super().__init__()
        self.layer = layer
        self.norm = nn.LayerNorm(embedding_size)

    def forward(self, queries, keys=None, values=None, valid_len=None):
        old_X = queries
        # MultiHeadAttention and FeedForward take different arguments
        if isinstance(self.layer, MultiHeadAttention):
            X = self.layer(queries, keys, values, valid_len)
        else:
            X = self.layer(queries)
        # residual connection followed by layer normalization ("Add & Norm")
        X = old_X + X
        return self.norm(X)
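
A short sketch, assuming the classes above, showing both ways a SubLayer is used (as the attention sub-layer and as the feed-forward sub-layer):

device = torch.device('cpu')
X = torch.randn(2, 10, 32)
attn_sub = SubLayer(MultiHeadAttention(32, 32, 32, 32, 4, device), 32)
ffn_sub = SubLayer(FeedForward(32), 32)
Y = attn_sub(X, X, X, torch.tensor([10, 6]))    # attention sub-layer
Z = ffn_sub(Y)                                  # feed-forward sub-layer
print(Z.shape)                                  # torch.Size([2, 10, 32])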

2.5 EncoderBlock

class EncoderBlock(nn.Module):
    def __init__(self, embedding_size, num_heads, device):
        super().__init__()
        self.device = device
        query_size = key_size = value_size = num_hiddens = embedding_size
        # the two sub-layers of an encoder block
        multiHeadAttention = MultiHeadAttention(query_size, key_size, value_size, num_hiddens, num_heads, self.device)
        feedForward = FeedForward(embedding_size)
        self.subLayer1 = SubLayer(multiHeadAttention, embedding_size)
        self.subLayer2 = SubLayer(feedForward, embedding_size)

    def forward(self, X, valid_len):
        # multi-head self-attention
        X = self.subLayer1(X, X, X, valid_len)
        # feed-forward
        X = self.subLayer2(X)
        return X
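
A minimal sketch assuming the classes above; one encoder block maps a (batch_size, num_steps, embedding_size) tensor to a tensor of the same shape:

device = torch.device('cpu')
block = EncoderBlock(embedding_size=32, num_heads=4, device=device)
X = torch.randn(2, 10, 32)
valid_len = torch.tensor([10, 6])
print(block(X, valid_len).shape)   # torch.Size([2, 10, 32])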

2.6 DecoderBlock

class DecoderBlock(nn.Module):
    def __init__(self, embedding_size, num_heads, i, device):
        super().__init__()
        self.device = device
        query_size = key_size = value_size = num_hiddens = embedding_size
        # the three sub-layers of a decoder block
        multiHeadAttention1 = MultiHeadAttention(query_size, key_size, value_size, num_hiddens, num_heads, self.device)
        multiHeadAttention2 = MultiHeadAttention(query_size, key_size, value_size, num_hiddens, num_heads, self.device)
        feedForward = FeedForward(embedding_size)
        self.subLayer1 = SubLayer(multiHeadAttention1, embedding_size)
        self.subLayer2 = SubLayer(multiHeadAttention2, embedding_size)
        self.subLayer3 = SubLayer(feedForward, embedding_size)
        # index of this block
        self.i = i
        # during prediction, front caches the key/value steps seen so far
        self.front = None

    def forward(self, encoder_output, encoder_valid_len, X):
        # during training the whole target sentence is fed at once; during prediction tokens arrive one step
        # at a time, so the previously seen steps must also serve as keys and values in self-attention
        if self.training:
            key_values = X
        else:
            key_values = X if self.front is None else torch.cat([self.front, X], dim=1)
            self.front = key_values
        # multi-head self-attention: no decoder valid_len is passed, so later positions are masked automatically
        X = self.subLayer1(X, key_values, key_values)
        # multi-head attention over the encoder output: encoder_valid_len masks the encoder padding
        X = self.subLayer2(X, encoder_output, encoder_output, encoder_valid_len)
        # feed-forward
        X = self.subLayer3(X)
        return X
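
A minimal sketch, assuming the classes above, of how the block behaves in the two modes; note that the cache front has to be reset by hand before decoding a new sequence:

device = torch.device('cpu')
block = DecoderBlock(embedding_size=32, num_heads=4, i=0, device=device)
enc_out = torch.randn(2, 10, 32)
enc_valid_len = torch.tensor([10, 6])

# training: the whole (shifted) target sequence is passed at once
block.train()
Y = torch.randn(2, 8, 32)
print(block(enc_out, enc_valid_len, Y).shape)   # torch.Size([2, 8, 32])

# prediction: one step at a time, with previous steps cached in block.front
block.eval()
block.front = None
for t in range(3):
    y_t = torch.randn(2, 1, 32)
    out_t = block(enc_out, enc_valid_len, y_t)
    print(out_t.shape, block.front.shape)       # (2, 1, 32) and a cache that grows to (2, t+1, 32)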

2.7 Encoder

class Encoder(nn.Module):
    def __init__(self, encoder_vocab_size, embedding_size, num_layers, num_heads, device):
        super().__init__()
        self.device = device
        # number of encoder layers
        self.num_layers = num_layers
        # token embedding layer
        self.embeddingLayer = nn.Embedding(encoder_vocab_size, embedding_size)
        # positional encoding layer
        self.positionalEncodingLayer = PositionalEncoding(device)
        # stack of encoder blocks
        self.encoderLayers = nn.ModuleList(
            [copy.deepcopy(EncoderBlock(embedding_size, num_heads, device)) for _ in range(num_layers)])
        self.embedding_size = embedding_size

    def forward(self, source, encoder_valid_len):
        # token embedding, scaled by sqrt(embedding_size) so it is not drowned out by the positional encoding
        X = self.embeddingLayer(source) * math.sqrt(self.embedding_size)
        # positional encoding
        positionalembedding = self.positionalEncodingLayer(X)
        X = X + positionalembedding
        for i in range(self.num_layers):
            X = self.encoderLayers[i](X, encoder_valid_len)
        return X
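
A minimal sketch assuming the classes above; the encoder consumes token ids and returns one embedding_size vector per source position:

device = torch.device('cpu')
encoder = Encoder(encoder_vocab_size=200, embedding_size=32, num_layers=2, num_heads=4, device=device)
source = torch.randint(0, 200, (2, 10))          # token ids: batch_size=2, num_steps=10
encoder_valid_len = torch.tensor([10, 6])
print(encoder(source, encoder_valid_len).shape)  # torch.Size([2, 10, 32])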

2.8 Decoder

class Decoder(nn.Module):
    def __init__(self, decoder_vocab_size, embedding_size, num_layers, num_heads, device):
        super().__init__()
        self.device = device
        # number of decoder layers
        self.num_layers = num_layers
        # token embedding layer
        self.embeddingLayer = nn.Embedding(decoder_vocab_size, embedding_size)
        # positional encoding layer
        self.positionalEncodingLayer = PositionalEncoding(device=self.device)
        # stack of decoder blocks
        self.decoderLayers = nn.ModuleList(
            [copy.deepcopy(DecoderBlock(embedding_size, num_heads, i, self.device)) for i in range(num_layers)])
        self.embedding_size = embedding_size

    def forward(self, encoder_output, encoder_valid_len, target):
        # token embedding, scaled by sqrt(embedding_size)
        X = self.embeddingLayer(target) * math.sqrt(self.embedding_size)
        # positional encoding
        positionalembedding = self.positionalEncodingLayer(X)
        X = X + positionalembedding
        for i in range(self.num_layers):
            X = self.decoderLayers[i](encoder_output, encoder_valid_len, X)
        return X
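
A minimal sketch assuming the classes above, in training mode, where the whole (shifted) target sequence is fed at once:

device = torch.device('cpu')
decoder = Decoder(decoder_vocab_size=180, embedding_size=32, num_layers=2, num_heads=4, device=device)
encoder_output = torch.randn(2, 10, 32)
encoder_valid_len = torch.tensor([10, 6])
target = torch.randint(0, 180, (2, 8))           # (shifted) target token ids
print(decoder(encoder_output, encoder_valid_len, target).shape)   # torch.Size([2, 8, 32])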

2.9 EncoderDecoder

class EncoderDecoder(nn.Module):
    def __init__(self, encoder_vocab_size, decoder_vocab_size, embedding_size, num_layers, num_heads, device):
        super().__init__()
        self.device = device
        self.encoder = Encoder(encoder_vocab_size, embedding_size, num_layers, num_heads, self.device)
        self.decoder = Decoder(decoder_vocab_size, embedding_size, num_layers, num_heads, self.device)
        # final projection onto the target vocabulary
        self.dense = nn.Linear(embedding_size, decoder_vocab_size)

    def forward(self, source, encoder_valid_len, target):
        encoder_output = self.encoder(source, encoder_valid_len)
        decoder_output = self.decoder(encoder_output, encoder_valid_len, target)
        return self.dense(decoder_output)
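
To put the pieces together, here is a minimal end-to-end sketch with made-up hyperparameters and random data, assuming everything above has been defined; a real setup would feed tokenized sentence pairs and mask the padded positions in the loss:

device = torch.device('cpu')
model = EncoderDecoder(encoder_vocab_size=200, decoder_vocab_size=180,
                       embedding_size=32, num_layers=2, num_heads=4, device=device)

source = torch.randint(0, 200, (2, 10))          # source token ids
encoder_valid_len = torch.tensor([10, 6])        # actual lengths before padding
target_in = torch.randint(0, 180, (2, 8))        # decoder input (e.g. <bos> + shifted target)
target_out = torch.randint(0, 180, (2, 8))       # labels the model should predict

logits = model(source, encoder_valid_len, target_in)
print(logits.shape)                              # torch.Size([2, 8, 180])

# one could train with an ordinary cross-entropy loss over the vocabulary
loss = F.cross_entropy(logits.reshape(-1, 180), target_out.reshape(-1))
print(loss.item())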