The principle of attention has already been explained in many places; the implementation follows the pseudocode of the transformer. Below is the simplest possible version.
import torch, math
from torch import nn

dropout_prob = 0.1

def forward(
    hidden_size,     # d
    input,           # (b, s, d)
    attention_mask   # (b, s, s)
):
    # NOTE: the projection layers are created here for readability only; in a real
    # module they would be defined once in __init__ so their weights persist and train.
    query = nn.Linear(hidden_size, hidden_size)  # (d, d)
    key = nn.Linear(hidden_size, hidden_size)
    value = nn.Linear(hidden_size, hidden_size)
    dropout = nn.Dropout(dropout_prob)

    query_layer = query(input)  # (b, s, d)
    key_layer = key(input)
    value_layer = value(input)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # (b, s, s)
    attention_scores = attention_scores / math.sqrt(hidden_size)
    if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)  # (b, s, s)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs)

    outputs = torch.matmul(attention_probs, value_layer)  # (b, s, s), (b, s, d) -> (b, s, d)
    return outputs
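As a quick sanity check, here is a minimal usage sketch for the function above. The batch/sequence/hidden sizes, the causal-mask construction, and the -10000.0 additive masking value are illustrative assumptions (the -10000.0 convention follows BERT-style additive masks), not part of the original snippet.

# Minimal usage sketch; all shapes and the mask construction are assumptions.
b, s, d = 2, 5, 16
x = torch.randn(b, s, d)                      # (b, s, d) random "hidden states"

# Additive causal mask: 0 where attention is allowed, a large negative
# value where it is not, so softmax pushes those positions toward 0.
causal = torch.tril(torch.ones(s, s))         # (s, s) lower-triangular
attention_mask = (1.0 - causal) * -10000.0    # (s, s)
attention_mask = attention_mask.unsqueeze(0)  # (1, s, s), broadcasts to (b, s, s)

out = forward(d, x, attention_mask)
print(out.shape)                              # torch.Size([2, 5, 16])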