
Attention Pseudocode Implementation (PyTorch Version)

The principles of attention have already been explained in many places; the pseudocode here follows the Transformer, and below is the simplest possible version.
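For reference, the code computes single-head scaled dot-product attention:

    Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V

where Q, K, and V are linear projections of the input and d is the hidden size.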

import torch, math
from torch import nn
dropout_prob = 0.1

def forward(
        hidden_size,    # d: hidden dimension
        input,          # (b, s, d): batch size, sequence length, hidden dimension
        attention_mask  # (b, s, s) additive mask, or None
):
    # For brevity the projections are created here; in a real nn.Module they
    # would be defined once in __init__ and reused across forward passes.
    query = nn.Linear(hidden_size, hidden_size)  # weight shape (d, d)
    key = nn.Linear(hidden_size, hidden_size)
    value = nn.Linear(hidden_size, hidden_size)
    dropout = nn.Dropout(dropout_prob)

    query_layer = query(input) #(b, s, d)
    key_layer = key(input)
    value_layer = value(input)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) #(b, s, s)
    attention_scores = attention_scores / math.sqrt(hidden_size)
    if attention_mask is not None:
        # Apply the additive attention mask (precomputed once for all layers, as in
        # BertModel's forward()); masked positions hold large negative values.
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1) #(b, s, s)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs)

    outputs = torch.matmul(attention_probs, value_layer) # (b, s, s), (b, s, d) -> (b, s, d)

    return outputs
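A minimal usage sketch (not part of the original post), assuming a toy batch and a hypothetical padding mask; the additive mask places a large negative value wherever the key position is padding:

b, s, d = 2, 4, 8                            # toy batch size, sequence length, hidden size
x = torch.randn(b, s, d)

padding_mask = torch.tensor([[1, 1, 1, 0],   # 1 = real token, 0 = padding
                             [1, 1, 0, 0]])
# Turn it into a (b, s, s) additive mask: 0 where attention is allowed,
# -1e4 where the key position is padding.
additive_mask = (1.0 - padding_mask[:, None, :].float()) * -1e4
additive_mask = additive_mask.expand(b, s, s)

out = forward(d, x, additive_mask)
print(out.shape)  # torch.Size([2, 4, 8])

In PyTorch 2.x the same computation is also available as torch.nn.functional.scaled_dot_product_attention, which accepts an additive mask through its attn_mask argument.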
