
PyTorch Notes 2: Transformer dim_feedforward


Sequence-to-Sequence modeling with nn.Transformer and torchtext

Train a sequence-to-sequence model with the nn.Transformer module.

PyTorch 1.2 and later include a standard transformer module based on the paper "Attention Is All You Need". The nn.Transformer module relies entirely on an attention mechanism to draw global dependencies between input and output.

nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', custom_encoder=None, custom_decoder=None)

A transformer model. Its parameters can be configured as needed. The architecture is based on the paper "Attention Is All You Need".

Parameters (a minimal usage sketch follows this list):

  • d_model: the number of expected features in the encoder/decoder inputs, i.e. the embedding dimension (default 512).
  • nhead: the number of heads in the multi-head attention models (default 8).
  • num_encoder_layers: the number of sub-encoder-layers in the encoder (default 6).
  • num_decoder_layers: the number of sub-decoder-layers in the decoder (default 6).
  • dim_feedforward: the dimension of the feedforward network model (default 2048).
  • dropout: the dropout value (default 0.1).
  • activation: the activation function of the encoder/decoder intermediate layer, relu or gelu (default relu).
  • custom_encoder: a custom encoder.
  • custom_decoder: a custom decoder.
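
A quick, minimal usage sketch with the default settings (the random tensors are only placeholders; the module expects (seq_len, batch, d_model) inputs by default):

import torch
import torch.nn as nn

# Instantiate nn.Transformer with its default settings.
transformer = nn.Transformer(d_model=512, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6,
                             dim_feedforward=2048, dropout=0.1)

src = torch.rand(10, 32, 512)   # source: 10 time steps, batch of 32, d_model=512
tgt = torch.rand(20, 32, 512)   # target: 20 time steps, batch of 32, d_model=512
out = transformer(src, tgt)
print(out.shape)                # torch.Size([20, 32, 512])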

[Figure: Transformer model architecture (Transformer_model.png)]
Encoder: the encoder is a stack of N = 6 identical layers. Each layer has two sub-layers: the first is a multi-head self-attention mechanism, and the second is a fully connected position-wise feed-forward network. Both sub-layers are wrapped with a residual connection and layer normalization, i.e. LayerNorm(x + Sublayer(x)). To make the residual connections possible, all sub-layers and embedding layers in the model produce outputs of dimension $d_{model} = 512$.
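
To make LayerNorm(x + Sublayer(x)) concrete, here is a minimal sketch of one residual sub-layer; the feed-forward block below is only a stand-in for the sub-layer, not the exact implementation inside nn.TransformerEncoderLayer:

import torch
import torch.nn as nn

d_model = 512
# A stand-in position-wise feed-forward sub-layer (dim_feedforward = 2048).
sublayer = nn.Sequential(nn.Linear(d_model, 2048), nn.ReLU(), nn.Linear(2048, d_model))
layer_norm = nn.LayerNorm(d_model)

x = torch.rand(35, 20, d_model)      # (seq_len, batch, d_model)
out = layer_norm(x + sublayer(x))    # residual connection, then layer normalization
print(out.shape)                     # torch.Size([35, 20, 512])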

1. Define the model

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerModel(nn.Module):
    
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()
    
    def generate_square_subsequent_mask(self, sz):
        # torch.triu keeps the diagonal and the elements above it and zeroes the rest;
        # the transpose then turns this into a lower-triangular boolean mask.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        # masked_fill sets positions where mask == 0 to '-inf' and positions where mask == 1 to 0.
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        # The returned mask looks like:
        """
        tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])
        """
        return mask
    
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, src, src_mask):
        # Scale the embeddings by sqrt(d_model) before adding the positional encodings.
        src = self.encoder(src) * math.sqrt(self.ninp)   #   torch.Size([35, 20, 200])
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)    # torch.Size([35, 20, 28783])
        return output
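
As a sanity check of what the subsequent mask does, this standalone sketch adds such a mask to a toy score matrix and applies softmax; the '-inf' positions receive exactly zero attention weight, so each position can only attend to itself and earlier positions:

import torch
import torch.nn.functional as F

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)

scores = torch.rand(sz, sz)                 # toy attention scores for 4 positions
weights = F.softmax(scores + mask, dim=-1)  # masked positions get weight 0
print(weights)                              # upper-triangular entries are all 0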

The PositionalEncoding module injects information about the relative and absolute positions of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed.

$$PE_{(pos, 2i)} = \sin\!\left(pos / 10000^{2i/d_{model}}\right), \qquad PE_{(pos, 2i+1)} = \cos\!\left(pos / 10000^{2i/d_{model}}\right)$$
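
In the code below, the $1/10000^{2i/d_{model}}$ factor is computed in log space, which avoids raising 10000 to large powers directly:

$$\text{div\_term} = \exp\!\left(-\frac{2i}{d_{model}}\ln 10000\right) = \frac{1}{10000^{2i/d_{model}}}$$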

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Build the positional-encoding matrix of shape (max_len, d_model), where max_len = 5000
        # is the assumed maximum sequence length and d_model is the token embedding dimension.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Register pe as a persistent buffer (saved with the module but not trained as a parameter).
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
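
A quick usage check, assuming the PositionalEncoding class defined above is in scope (eval() is called only to disable dropout for the check):

import torch

pos_encoder = PositionalEncoding(d_model=200, dropout=0.1)
pos_encoder.eval()                  # disable dropout so the output is deterministic
x = torch.zeros(35, 20, 200)        # (seq_len, batch, d_model) dummy embeddings
y = pos_encoder(x)
print(y.shape)                      # torch.Size([35, 20, 200]); each of the 35 positions
                                    # now carries a distinct sinusoidal offset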

2. Load and batch data

The vocab is built from the training dataset.

The effect of the batchify function is shown in the figure below; a small sketch after the figure walks through the same reshaping.
[Figure: batchify splits the sequence A..Z into batch_size columns (abcd.png)]
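
Here is a small, self-contained sketch of that reshaping, using integers 0..25 to stand in for the tokens A..Z (the helper mirrors the batchify function defined further below):

import torch

def toy_batchify(data, bsz):
    # Trim off the remainder and reshape into bsz columns, one sequence per column.
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    return data.view(bsz, -1).t().contiguous()

letters = torch.arange(26)     # stands in for the tokens A..Z
print(toy_batchify(letters, 4))
# tensor([[ 0,  6, 12, 18],
#         [ 1,  7, 13, 19],
#         [ 2,  8, 14, 20],
#         [ 3,  9, 15, 21],
#         [ 4, 10, 16, 22],
#         [ 5, 11, 17, 23]])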

import io
import torch
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
# Download and extract the archive from the URL, producing three files: wiki.test.tokens, wiki.train.tokens, wiki.valid.tokens.
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
# Tokenizer that normalizes each sequence (strips special characters, lowercases) and splits it into tokens.
tokenizer = get_tokenizer('basic_english')
# Build the vocab from the training set, so it behaves like {'I': 1, ..., 'am': 1732, ...}.
vocab = build_vocab_from_iterator(map(tokenizer,
                                     iter(io.open(train_filepath,
                                                 encoding='utf-8'))))

def data_process(raw_text_iter):
    """
    raw_text_iter: an iterator over the lines of the text file.
    tokenizer: preprocesses each sequence (strips special characters and whitespace, splits),
        producing token lists like ['i', 'love', 'pytorch', ...].
    vocab: a dict-like mapping from each token to a unique integer id, used to turn a
        sequence of tokens into a sequence of ids.
    Returns: the per-line id tensors with empty ones filtered out, concatenated by torch.cat
        into a single tensor like tensor([10, 3850, 3870, ..., 2443]).
    """
    data = [torch.tensor([vocab[token] for token in tokenizer(item)],
                        dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel()>0, data)))

train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    """
    若data的长度是26,batch_size=4,那么我们将分为4个sequence,每个sequence长度为6。
    """
    # 每个batch
    nbatch = data.size(0) // bsz
    # 函数narrow参数:(dim, start, length),dim是narrow要操作的维度,start,length表示在该维度上进行切片[start: start+length]
    data = data.narrow(0, 0, nbatch*bsz)
    # contiguous 函数一般与transpose,permute,view搭配使用,使得tensor内存中连续存储
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
print(train_data.size())
36718lines [00:01, 25799.69lines/s]


torch.Size([102499, 20])

3. Functions to generate input and target sequence

get_batch() builds the input and target for the transformer model, cutting the source into chunks of length bptt. Continuing with the sequence [A, B, C, ..., X, Y, Z], with bptt = 2 and batch_size still equal to 4, the input/target pairs look as in the figure below (a small sketch after the code walks through the first chunk).
[Figure: get_batch input chunk and target chunk, shifted by one position (target.png)]

bptt = 35
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i: i+seq_len]   #   torch.Size([35, 20])
    target = source[i+1: i+1+seq_len].reshape(-1)    #   torch.Size([700])
    return data, target
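
Continuing the toy alphabet example with bptt = 2, here is a small self-contained sketch of the first chunk (the 6x4 tensor below is the batchified alphabet from section 2):

import torch

toy_source = torch.arange(24).view(4, -1).t().contiguous()  # the 6x4 batchified alphabet

toy_bptt = 2
i = 0
seq_len = min(toy_bptt, len(toy_source) - 1 - i)
data = toy_source[i: i + seq_len]                        # rows 0-1: the input chunk
target = toy_source[i + 1: i + 1 + seq_len].reshape(-1)  # rows 1-2, flattened: shifted by one step
print(data)
# tensor([[ 0,  6, 12, 18],
#         [ 1,  7, 13, 19]])
print(target)
# tensor([ 1,  7, 13, 19,  2,  8, 14, 20])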

4. Initiate an instance

The hyperparameters are set as follows; the vocabulary size equals the length of the vocab object.

ntokens = len(vocab.stoi)    # the size of vocabulary
emsize = 200         # embedding dimension
nhid = 200           # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2          # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2            # the number of heads in the multiheadattention models
dropout = 0.2        # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
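
A quick shape check before training, assuming the model, get_batch, train_data, bptt, and device from the earlier cells are in scope:

src_mask = model.generate_square_subsequent_mask(bptt).to(device)
data, targets = get_batch(train_data, 0)
with torch.no_grad():
    output = model(data, src_mask)
print(data.shape, output.shape)   # e.g. torch.Size([35, 20]) torch.Size([35, 20, 28783])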

5. Run the model

# loss object
criterion = nn.CrossEntropyLoss()
# learning rate
lr = 5.0 
# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Start from lr = 5.0 and let StepLR decay it automatically as the epochs progress.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

import time
def train():
    # turn on the train mode
    model.train()
    total_loss = 0.0
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        # print(output.size())   #  torch.Size([35, 20, 28783])
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        # Clip the gradients to prevent them from exploding.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    eval_model.eval()        # turn on the evaluation mode
    total_loss = 0.0
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source)-1)

Loop over the epochs and keep the model whenever the validation loss is the best seen so far. Adjust the learning rate after each epoch.

best_val_loss = float("inf")
epochs = 3
best_model = None

for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-'*89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-'*89)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model
    # Adjust the learning rate after each epoch
    scheduler.step()
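
The loop above only keeps a reference to the best model in memory. A minimal sketch of also checkpointing it to disk, assuming the earlier cells are in scope (the filename is illustrative):

# Persist the best model's parameters (illustrative filename).
torch.save(best_model.state_dict(), 'best_transformer_lm.pt')

# Reload later into a freshly constructed model of the same shape.
restored = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
restored.load_state_dict(torch.load('best_transformer_lm.pt'))
restored.eval()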
| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 12.63 | loss  5.48 | ppl   240.28
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 12.38 | loss  5.51 | ppl   246.10
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 12.42 | loss  5.31 | ppl   202.03
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 12.63 | loss  5.38 | ppl   216.16
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 12.61 | loss  5.35 | ppl   209.95
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 12.59 | loss  5.39 | ppl   218.51
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 12.48 | loss  5.41 | ppl   223.10
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 12.50 | loss  5.44 | ppl   230.27
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 14.89 | loss  5.39 | ppl   218.99
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 13.27 | loss  5.42 | ppl   226.92
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 12.62 | loss  5.27 | ppl   195.07
| epoch   1 |  2400/ 2928 batches | lr 5.00 | ms/batch 12.56 | loss  5.40 | ppl   221.41
| epoch   1 |  2600/ 2928 batches | lr 5.00 | ms/batch 12.62 | loss  5.40 | ppl   222.03
| epoch   1 |  2800/ 2928 batches | lr 5.00 | ms/batch 12.68 | loss  5.34 | ppl   208.62
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 39.47s | valid loss  5.58 | valid ppl   264.86
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2928 batches | lr 4.51 | ms/batch 12.75 | loss  5.37 | ppl   214.61
| epoch   2 |   400/ 2928 batches | lr 4.51 | ms/batch 13.61 | loss  5.40 | ppl   220.96
| epoch   2 |   600/ 2928 batches | lr 4.51 | ms/batch 13.16 | loss  5.20 | ppl   181.73
| epoch   2 |   800/ 2928 batches | lr 4.51 | ms/batch 13.10 | loss  5.28 | ppl   196.55
| epoch   2 |  1000/ 2928 batches | lr 4.51 | ms/batch 13.23 | loss  5.27 | ppl   193.49
| epoch   2 |  1200/ 2928 batches | lr 4.51 | ms/batch 13.56 | loss  5.28 | ppl   196.04
| epoch   2 |  1400/ 2928 batches | lr 4.51 | ms/batch 12.68 | loss  5.30 | ppl   200.68
| epoch   2 |  1600/ 2928 batches | lr 4.51 | ms/batch 12.65 | loss  5.34 | ppl   208.51
| epoch   2 |  1800/ 2928 batches | lr 4.51 | ms/batch 13.01 | loss  5.29 | ppl   199.06
| epoch   2 |  2000/ 2928 batches | lr 4.51 | ms/batch 14.39 | loss  5.29 | ppl   199.01
| epoch   2 |  2200/ 2928 batches | lr 4.51 | ms/batch 12.66 | loss  5.17 | ppl   175.10
| epoch   2 |  2400/ 2928 batches | lr 4.51 | ms/batch 12.67 | loss  5.28 | ppl   195.85
| epoch   2 |  2600/ 2928 batches | lr 4.51 | ms/batch 12.67 | loss  5.31 | ppl   202.56
| epoch   2 |  2800/ 2928 batches | lr 4.51 | ms/batch 12.73 | loss  5.22 | ppl   184.74
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 40.29s | valid loss  5.58 | valid ppl   266.35
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2928 batches | lr 4.29 | ms/batch 12.78 | loss  5.25 | ppl   191.13
| epoch   3 |   400/ 2928 batches | lr 4.29 | ms/batch 12.69 | loss  5.28 | ppl   196.60
| epoch   3 |   600/ 2928 batches | lr 4.29 | ms/batch 12.68 | loss  5.10 | ppl   163.39
| epoch   3 |   800/ 2928 batches | lr 4.29 | ms/batch 13.14 | loss  5.17 | ppl   176.29
| epoch   3 |  1000/ 2928 batches | lr 4.29 | ms/batch 14.30 | loss  5.13 | ppl   168.89
| epoch   3 |  1200/ 2928 batches | lr 4.29 | ms/batch 12.64 | loss  5.17 | ppl   176.30
| epoch   3 |  1400/ 2928 batches | lr 4.29 | ms/batch 12.99 | loss  5.20 | ppl   180.48
| epoch   3 |  1600/ 2928 batches | lr 4.29 | ms/batch 12.67 | loss  5.23 | ppl   186.36
| epoch   3 |  1800/ 2928 batches | lr 4.29 | ms/batch 12.91 | loss  5.19 | ppl   179.06
| epoch   3 |  2000/ 2928 batches | lr 4.29 | ms/batch 12.73 | loss  5.20 | ppl   181.16
| epoch   3 |  2200/ 2928 batches | lr 4.29 | ms/batch 12.68 | loss  5.06 | ppl   157.73
| epoch   3 |  2400/ 2928 batches | lr 4.29 | ms/batch 12.81 | loss  5.18 | ppl   178.23
| epoch   3 |  2600/ 2928 batches | lr 4.29 | ms/batch 12.82 | loss  5.20 | ppl   180.67
| epoch   3 |  2800/ 2928 batches | lr 4.29 | ms/batch 12.62 | loss  5.14 | ppl   169.99
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 39.80s | valid loss  5.55 | valid ppl   256.71
-----------------------------------------------------------------------------------------

6. Evaluate the model with the test dataset

test_loss = evaluate(best_model, test_data)
print('='*89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)
=========================================================================================
| End of training | test loss  5.46 | test ppl   234.11
=========================================================================================
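
As a reminder, criterion is the mean cross-entropy per token, so the reported perplexity is simply its exponential: ppl = exp(loss), e.g. exp(5.46) ≈ 235, matching the printed value up to rounding of the loss.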