Main class of the Transformer graph
The processing in the Transformer can be viewed as two-stage message passing on the complete graph (with pre- and post-processing added where appropriate): 1) self-attention in the encoder, and 2) self-attention in the decoder followed by cross-attention between the encoder and the decoder, as shown below.
import numpy as np
import torch.nn as nn
import dgl.function as fn

# src_dot_dst and scaled_exp are the edge UDFs defined elsewhere in this tutorial.

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
        self.h, self.d_k = h, d_k

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."
        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        nids, eids = graph.nids, graph.eids
        # Word Embedding and Position Embedding
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
        g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
        g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)
        for i in range(self.encoder.N):
            # Step 1: Encoder Self-attention
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            nodes, edges = nids['enc'], eids['ee']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])

        for i in range(self.decoder.N):
            # Step 2: Decoder Self-attention
            pre_func = self.decoder.pre_func(i, 'qkv')
            post_func = self.decoder.post_func(i)
            nodes, edges = nids['dec'], eids['dd']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            # Step 3: Encoder-Decoder attention
            pre_q = self.decoder.pre_func(i, 'q', 1)
            pre_kv = self.decoder.pre_func(i, 'kv', 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']
            self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])
        return self.generator(g.ndata['x'][nids['dec']])
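The propagate_attention method relies on two edge-wise UDFs, src_dot_dst and scaled_exp, that are defined earlier in the tutorial. As a rough sketch (the exact field handling and the clamping constant below are assumptions, not the tutorial's verbatim code), they could look like this:

import torch as th

def src_dot_dst(src_field, dst_field, out_field):
    # Dot product between the source node's key and the destination node's
    # query, stored as an un-normalized attention score on each edge.
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
    return func

def scaled_exp(field, scale_constant):
    # Divide by sqrt(d_k) and exponentiate; the softmax denominator is
    # recovered later from the accumulated 'z' field on the target nodes.
    def func(edges):
        # Clamping the logits is a common numerical-stability trick (assumed here).
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}
    return func

send_and_recv then aggregates the value messages into 'wv' and the score sum into 'z' on the target nodes, so the softmax-normalized attention output wv / z can be formed in the post-processing function.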
Note

By calling the update_graph function, you can create your own Transformer on any subgraph with nearly the same code. This flexibility lets us discover new, sparse structures (c.f. the local attention mentioned here). Note that this implementation uses no mask or padding, which makes the logic clearer and saves memory. The trade-off is that the implementation is slower.
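For instance, restricting decoder self-attention to a sparse edge set only requires handing update_graph a different list of edge IDs; the rest of the code is unchanged. The snippet below is a minimal sketch that modifies the Step 2 loop in forward, assuming a hypothetical eids['dd_local'] that contains only decoder-to-decoder edges within a fixed window (it is not part of the tutorial's graph-building code):

# Hypothetical: 'dd_local' holds only decoder->decoder edges whose positions
# are at most `window` apart, giving local (banded) self-attention.
for i in range(self.decoder.N):
    pre_func = self.decoder.pre_func(i, 'qkv')
    post_func = self.decoder.post_func(i)
    nodes, edges = nids['dec'], eids['dd_local']  # sparse edge set instead of eids['dd']
    self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])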