
Transformer usage examples


I summarized some transformer basics a while back while watching Hung-yi Lee's lecture videos (see here). By the time of writing this post I had mostly forgotten them, so instead of digging into the theory again, this article walks through two basic transformer applications to make the model easier to understand and put to use.

Sequence labeling

Reference file: transformer_postag.py.

1. Loading the data

# load the data
train_data, test_data, vocab, pos_vocab = load_treebank()

The load_treebank code:

def load_treebank():
    # downloading may require a proxy; adjust or remove this line as needed
    nltk.set_proxy('http://192.168.0.28:1080')
    # downloads the corpus if it is not already present, otherwise does nothing
    nltk.download('treebank')
    from nltk.corpus import treebank

    sents, postags = zip(*(zip(*sent) for sent in treebank.tagged_sents()))
    vocab = Vocab.build(sents, reserved_tokens=["<pad>"])
    tag_vocab = Vocab.build(postags)
    train_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags))
                  for sentence, tags in zip(sents[:3000], postags[:3000])]
    test_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags))
                 for sentence, tags in zip(sents[3000:], postags[3000:])]
    return train_data, test_data, vocab, tag_vocab

After loading, both train_data and test_data are lists, where each sample is a tuple of (input, target) id sequences, for example:

>>> train_data[0]
Out[1]: ([2, 3, 4, 5, 6, 7, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [1, 1, 2, 3, 4, 5, 2, 6, 7, 8, 9, 10, 8, 5, 9, 1, 3, 11])
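The Vocab class used by load_treebank comes from the utilities accompanying transformer_postag.py and is not shown in this post. For reference, here is a minimal sketch of what such a class might look like, matching the build / convert_tokens_to_ids / vocab["<pad>"] calls used here; the internals are an assumption, not the original code:

from collections import defaultdict

class Vocab:
    def __init__(self, tokens=None):
        # token <-> id mappings; "<unk>" is assumed to be the fallback token
        self.idx_to_token = list(tokens) if tokens else []
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
        self.unk = self.token_to_idx.get("<unk>", 0)

    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        # count token frequencies over all sentences
        freq = defaultdict(int)
        for sentence in text:
            for token in sentence:
                freq[token] += 1
        uniq_tokens = ["<unk>"] + (reserved_tokens or [])
        uniq_tokens += [tok for tok, cnt in freq.items() if cnt >= min_freq and tok != "<unk>"]
        return cls(uniq_tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        # used as vocab["<pad>"] in collate_fn below
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        return [self[token] for token in tokens]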

2. Data processing

# pad each batch to equal length with <pad>; whether the padding id is 0, 1 or
# anything else does not matter, because the mask marks which positions are real
def collate_fn(examples):
    lengths = torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])
    return inputs, lengths, targets, inputs != vocab["<pad>"]
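With collate_fn defined, batching is just a matter of handing the data lists to a DataLoader. A minimal usage sketch follows; the batch size and shuffle settings are illustrative choices, not taken from the original script:

from torch.utils.data import DataLoader

# each batch is (inputs, lengths, targets, mask), exactly what collate_fn returns
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate_fn)

inputs, lengths, targets, mask = next(iter(train_loader))
print(inputs.shape, targets.shape, mask.shape)  # all (batch_size, longest_sequence_in_batch)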

3. The model

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x


class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                 dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
        super(Transformer, self).__init__()
        # word embedding layer
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
        # encoder: a stack of Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # output layer
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        inputs = torch.transpose(inputs, 0, 1)
        hidden_states = self.embeddings(inputs)
        hidden_states = self.position_embedding(hidden_states)
        attention_mask = length_to_mask(lengths) == False
        hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
        logits = self.output(hidden_states)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs
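The forward pass calls length_to_mask, which also lives in the accompanying utilities and is not shown in this post. Below is a minimal sketch consistent with how it is used above: it returns True at real token positions, and the == False flip then produces the padding mask that src_key_padding_mask expects (True = position to ignore):

import torch

def length_to_mask(lengths):
    # lengths: (batch_size,) tensor of true sequence lengths;
    # returns a bool mask of shape (batch_size, max_len), True at valid positions
    max_len = int(lengths.max().item())
    return torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)

# length_to_mask(torch.tensor([3, 1]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])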

A few points worth noting here:

  • PositionalEncoding

Self-attention, unlike an RNN, carries no inherent notion of token order, so the transformer introduces positional encoding: an absolute-position code is added to every input embedding (see self.pe, which is a static lookup table). Relative positional encodings, which inject relative rather than absolute position information, have also appeared and are quite common in NER, e.g. TENER and Flat-Lattice-Transformer. That said, recent work from Google argues that such relative position encodings mainly just feed additional informative features into the model.
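A quick way to convince yourself that self.pe is a static lookup table: instantiate the PositionalEncoding class above with small sizes (d_model=8 and max_len=16 are arbitrary choices for this demo) and check that the additive offset depends only on the position, never on the input values:

import torch

pos_enc = PositionalEncoding(d_model=8, max_len=16)
print(pos_enc.pe.shape)  # torch.Size([16, 1, 8]) -> (max_len, 1, d_model)

# two different inputs of the same shape receive exactly the same offsets
x1 = torch.zeros(4, 2, 8)  # (seq_len, batch, d_model)
x2 = torch.rand(4, 2, 8)
print(torch.allclose(pos_enc(x1) - x1, pos_enc(x2) - x2))  # True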
