
Transformer as a Graph Neural Network

Author: Zihao Ye, Jinjing Zhou, Qipeng Guo, Quan Gan, Zheng Zhang

Warning

This tutorial aims at gaining insights into the paper, with code as a means of explanation. The implementation is therefore NOT optimized for running efficiency. For the recommended implementation, please refer to the official examples.

In this tutorial, you learn about a simplified implementation of the Transformer model. You can see highlights of the most important design points. For instance, there is only single-head attention. The complete code can be found here.

The overall structure is similar to the one from the research paper The Annotated Transformer.

The Transformer model, introduced in the research paper Attention is All You Need as a replacement for CNN/RNN architectures in sequence modeling, improved the state of the art for machine translation as well as natural language inference (GPT). Recent work on pre-training Transformers with large-scale corpora (BERT) shows that they are capable of learning high-quality semantic representations.

The interesting part of Transformer is its extensive employment of attention. The classic use of attention comes from the machine translation model, where the output token attends to all input tokens.

Transformer additionally applies self-attention in both the decoder and the encoder. This process forces words related to each other to be combined, irrespective of their positions in the sequence. This is different from an RNN-based model, where words (in the source sentence) are combined along a chain, which is thought to be too constrained.

Attention layer of Transformer

In the attention layer of Transformer, for each node the module learns to assign weights to its incoming edges. For a node pair (i, j) (from i to j) with node features x_i, x_j ∈ R^n, the score of their connection is defined as follows:

$$q_j = W_q \cdot x_j \qquad k_i = W_k \cdot x_i \qquad v_i = W_v \cdot x_i \qquad \mathrm{score} = q_j^T k_i$$

where W_q, W_k, W_v ∈ R^{n×d_k} map the representation x to the "query", "key", and "value" spaces respectively.

There are other possibilities to implement the score function. The dot product measures the similarity of a given query q_j and a key k_i: if j needs the information stored in i, the query vector at position j (q_j) is supposed to be close to the key vector at position i (k_i).

The score is then used to compute the sum of the incoming values, normalized over the weights of the edges and stored in wv. An affine layer is then applied to wv to get the output o:

$$w_{ji} = \frac{\exp\{\mathrm{score}_{ji}\}}{\sum_{(k,i)\in E} \exp\{\mathrm{score}_{ki}\}} \qquad \mathrm{wv}_i = \sum_{(k,i)\in E} w_{ki} v_k \qquad o = W_o \cdot \mathrm{wv}$$
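To make the formulas concrete, here is a minimal PyTorch sketch of the same computation for a single destination node j with a handful of incoming edges. All tensor shapes and names here are illustrative only and are not part of the tutorial code; the 1/sqrt(d_k) scaling anticipates the scaled dot product used later.

import math
import torch as th

n, d_k = 16, 8                          # model dimension and query/key dimension (illustrative)
W_q, W_k, W_v = th.randn(d_k, n), th.randn(d_k, n), th.randn(d_k, n)
W_o = th.randn(n, d_k)

x_j = th.randn(n)                       # destination node j (it holds the query)
x_src = th.randn(5, n)                  # five source nodes i sending messages to j

q_j = W_q @ x_j                         # query of node j, shape (d_k,)
k = x_src @ W_k.T                       # keys of the source nodes, shape (5, d_k)
v = x_src @ W_v.T                       # values of the source nodes, shape (5, d_k)

score = (k @ q_j) / math.sqrt(d_k)      # scaled dot-product scores, shape (5,)
w = th.softmax(score, dim=0)            # normalize over the incoming edges
wv_j = (w.unsqueeze(-1) * v).sum(0)     # weighted sum of values, shape (d_k,)
o_j = W_o @ wv_j                        # output for node j, shape (n,)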

Multi-head attention layer

In Transformer, attention is multi-headed. A head is very much like a channel in a convolutional network. The multi-head attention consists of multiple attention heads, in which each head refers to a single attention module. wv^{(i)} for all the heads are concatenated and mapped to the output o with an affine layer:

$$o = W_o \cdot \mathrm{concat}([\mathrm{wv}^{(0)}, \mathrm{wv}^{(1)}, \cdots, \mathrm{wv}^{(h)}])$$

The code below wraps the necessary components for multi-head attention and provides two interfaces.

  • get maps the state 'x' to queries, keys and values, which are required by the following step (propagate_attention).

  • get_o maps the updated value after attention to the output o for post-processing.

class MultiHeadAttention(nn.Module):
    "Multi-Head Attention"
    def __init__(self, h, dim_model):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model), 4)

    def get(self, x, fields='qkv'):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if 'k' in fields:
            ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if 'v' in fields:
            ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))

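The snippet above relies on a clones helper that this excerpt does not define. In The Annotated Transformer, which this implementation follows, it is just a list of deep copies; the sketch below reproduces that assumed helper.

import copy
import torch.nn as nn

def clones(module, N):
    "Produce N independent deep copies of a module, e.g. the four linear maps above."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])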

How DGL implements Transformer with a graph neural network

You get a different perspective of Transformer by treating the attention as edges in a graph and adopting message passing on the edges to induce the appropriate processing.

Graph structure

Construct the graph by mapping tokens of the source and target sentence to nodes. The complete Transformer graph is made up of three subgraphs:

Source language graph. This is a complete graph; each token s_i can attend to any other token s_j (including self-loops).

image0

Target language graph. The graph is half-complete, in that t_i attends only to t_j if i > j (an output token cannot depend on future words).

image1

Cross-language graph. This is a bipartite graph, where there is an edge from every source token s_i to every target token t_j, meaning every target token can attend to the source tokens.

image2

The full picture looks like this:

image3

Pre-build the graphs in the dataset preparation stage.
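To make the three subgraphs concrete, here is a small sketch that enumerates the ee, dd and ed edge lists for one (source length, target length) pair in plain Python. It only illustrates the connectivity; whether the decoder graph includes self-loops and how the lists are turned into an actual DGL graph are decided by the dataset code, which this excerpt does not show.

def build_edges(m, n):
    """Edge lists for a source sentence of length m and a target sentence of length n.
    Encoder nodes are numbered 0..m-1, decoder nodes m..m+n-1."""
    ee = [(i, j) for i in range(m) for j in range(m)]              # source: complete graph with self-loops
    dd = [(m + i, m + j) for j in range(n) for i in range(j + 1)]  # target: position j receives only from i <= j
    ed = [(i, m + j) for i in range(m) for j in range(n)]          # every source token -> every target token
    return ee, dd, ed

ee, dd, ed = build_edges(3, 2)
print(len(ee), len(dd), len(ed))  # 9 3 6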

Message passing

Once you define the graph structure, move on to defining the computation for message passing.

Assuming that you have already computed all the queries q_i, keys k_i and values v_i, for each node i (no matter whether it is a source token or a target token) you can decompose the attention computation into two steps:

  1. Message computation: Compute the attention score score_{ij} between i and all nodes j to be attended over, by taking the scaled dot product between q_i and k_j. The message sent from j to i consists of the score score_{ij} and the value v_j.

  2. Message aggregation: Aggregate the values v_j from all j according to the scores score_{ij}.

Simple implementation
Message computation

Compute score and send the source node's v to the destination node's mailbox

def message_func(edges):
    return {'score': ((edges.src['k'] * edges.dst['q'])
                      .sum(-1, keepdim=True)),
            'v': edges.src['v']}


Message aggregation

Normalize over all in-edges and take the weighted sum to get the output

import math

import torch as th
import torch.nn.functional as F

def reduce_func(nodes, d_k=64):
    v = nodes.mailbox['v']
    # d_k is a plain Python int, so use math.sqrt here (th.sqrt expects a tensor)
    att = F.softmax(nodes.mailbox['score'] / math.sqrt(d_k), 1)
    return {'dx': (att * v).sum(1)}


Execute on specific edges

from functools import partial

def naive_propagate_attention(self, g, eids):
    g.send_and_recv(eids, message_func, partial(reduce_func, d_k=self.d_k))


Speeding up with built-in functions

To speed up the message passing process, use DGL's built-in functions, including:

  • fn.src_mul_edge(src_field, edges_field, out_field) multiplies the source node's attribute and the edge attribute, and sends the result to the destination node's mailbox keyed by out_field.

  • fn.copy_edge(edges_field, out_field) copies the edge's attribute to the destination node's mailbox.

  • fn.sum(edges_field, out_field) sums up the edge's attribute and sends the aggregation to the destination node's mailbox.

Here, you assemble those built-in functions into propagate_attention, which is also the main graph operation function in the final implementation. To accelerate it, break the softmax operation into the following steps. Recall that for each head there are two phases.

  1. Compute the attention score by multiplying the src node's k and the dst node's q

    • g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)

  2. Scaled softmax over all dst nodes' incoming edges

    • Step 1: Exponentiate the score with the scale normalization constant

      • g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))

        $$\mathrm{score}_{ij} \leftarrow \exp\left(\frac{\mathrm{score}_{ij}}{\sqrt{d_k}}\right)$$

    • Step 2: Get the "values" on the associated nodes weighted by the "scores" on the incoming edges of each node; get the sum of the "scores" on the incoming edges of each node for normalization. Note that here wv is not normalized.

      • msg: fn.src_mul_edge('v', 'score', 'v'), reduce: fn.sum('v', 'wv')

        $$\mathrm{wv}_j = \sum_{i=1}^{N} \mathrm{score}_{ij} \cdot v_i$$

      • msg: fn.copy_edge('score', 'score'), reduce: fn.sum('score', 'z')

        $$z_j = \sum_{i=1}^{N} \mathrm{score}_{ij}$$

The normalization of wv is left to post-processing.

def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func


def propagate_attention(self, g, eids):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
    # Update node state
    g.send_and_recv(eids,
                    [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                    [fn.sum('v', 'wv'), fn.sum('score', 'z')])


Preprocessing and postprocessing

In Transformer, data needs to be pre- and post-processed before and after the propagate_attention function.

Preprocessing The preprocessing function pre_func first normalizes the node representations and then maps them to a set of queries, keys and values, using self-attention as an example:

$$x \leftarrow \mathrm{LayerNorm}(x) \qquad [q, k, v] \leftarrow [W_q, W_k, W_v] \cdot x$$

Postprocessing The postprocessing function post_funcs completes the whole computation corresponding to one layer of the Transformer:

  1. Normalize wv and get the output of the multi-head attention layer o,

     $$\mathrm{wv} \leftarrow \frac{\mathrm{wv}}{z} \qquad o \leftarrow W_o \cdot \mathrm{wv} + b_o$$

     then add the residual connection:

     $$x \leftarrow x + o$$

  2. Apply a two-layer position-wise feed-forward layer on x, then add the residual connection:

     $$x \leftarrow x + \mathrm{LayerNorm}(\mathrm{FFN}(x))$$

     where FFN refers to the feed-forward function.
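The feed_forward sub-module used by the code below is not defined in this excerpt. A minimal sketch of a two-layer position-wise feed-forward network in the style of The Annotated Transformer follows; the class name, the ReLU non-linearity and the dropout placement are assumptions.

import torch
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    "FFN(x) = W2 * relu(W1 * x + b1) + b2, applied independently at every position."
    def __init__(self, dim_model, dim_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(dim_model, dim_ff)
        self.w_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(torch.relu(self.w_1(x))))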

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func

    def post_func(self, i):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv', l=0):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            if fields == 'kv':
                norm_x = x # In enc-dec attention, x has already been normalized.
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)
        return func

    def post_func(self, i, l=0):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func


This completes all the procedures of one layer of the encoder and decoder in Transformer.

Note

The sublayer connection part is a little bit different from the original paper. However, this implementation is the same as The Annotated Transformer and OpenNMT.
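The layer.sublayer[...] objects used in Encoder and Decoder come from a SublayerConnection module that this excerpt does not define. The sketch below shows the assumed pre-norm residual pattern from The Annotated Transformer (with nn.LayerNorm standing in for the tutorial's own LayerNorm class); it is consistent with how pre_func and post_func access .norm and .dropout above.

import torch.nn as nn

class SublayerConnection(nn.Module):
    "Residual connection around a sub-layer, with layer normalization applied first (pre-norm)."
    def __init__(self, size, dropout=0.1):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # the pattern used above as layer.sublayer[1](x, layer.feed_forward)
        return x + self.dropout(sublayer(self.norm(x)))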

Main class of the Transformer graph

The processing flow of Transformer can be seen as a two-stage message passing within the complete graph (with pre- and post-processing added appropriately): 1) self-attention in the encoder, 2) self-attention in the decoder followed by cross-attention between the encoder and decoder, as shown below.

image4

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
        self.h, self.d_k = h, d_k

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        nids, eids = graph.nids, graph.eids

        # Word Embedding and Position Embedding
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
        g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
        g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)

        for i in range(self.encoder.N):
            # Step 1: Encoder Self-attention
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            nodes, edges = nids['enc'], eids['ee']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])

        for i in range(self.decoder.N):
            # Step 2: Decoder Self-attention
            pre_func = self.decoder.pre_func(i, 'qkv')
            post_func = self.decoder.post_func(i)
            nodes, edges = nids['dec'], eids['dd']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            # Step 3: Encoder-Decoder attention
            pre_q = self.decoder.pre_func(i, 'q', 1)
            pre_kv = self.decoder.pre_func(i, 'kv', 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']
            self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])

        return self.generator(g.ndata['x'][nids['dec']])


Note

By calling the update_graph function, you can create your own Transformer on any subgraph with almost the same code. This flexibility allows us to discover new, sparse structures (see the local attention mentioned here; a small sketch of such a sparse pattern follows below). Note that this implementation does not use a mask or padding, which makes the logic cleaner and saves memory. The trade-off is that the implementation is slower.
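For instance, restricting encoder self-attention to a local window only requires building a different ee edge list and passing its edge IDs to update_graph; everything else stays unchanged. Below is a minimal sketch of such a banded pattern (the window size and function name are illustrative, not part of the tutorial code).

def local_attention_edges(m, window=2):
    "Each of the m source tokens attends only to tokens within `window` positions of itself."
    return [(i, j) for j in range(m) for i in range(m) if abs(i - j) <= window]

# feeding these edges (instead of the complete 'ee' set) into update_graph
# gives a banded, local self-attention pattern; the window size is arbitrary here
edges = local_attention_edges(6, window=1)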

Training

This tutorial does not cover several other techniques mentioned in the original paper, such as label smoothing and Noam optimization. For a detailed description of these modules, read The Annotated Transformer written by the Harvard NLP team.
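For reference, the Noam schedule used by NoamOpt further below follows the learning-rate formula from Attention is All You Need: a linear warm-up followed by inverse-square-root decay. The sketch shows only the rate computation; the optimizer wrapper itself comes from The Annotated Transformer and is not reproduced here. The factor 1 and warmup 400 mirror the NoamOpt(dim_model, 1, 400, ...) call in the training code.

def noam_rate(step, dim_model, factor=1, warmup=400):
    "lrate = factor * dim_model^-0.5 * min(step^-0.5, step * warmup^-1.5)"
    step = max(step, 1)  # guard against step = 0
    return factor * dim_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

# the rate grows linearly during warm-up, peaks around step == warmup, then decays as 1/sqrt(step)
rates = [noam_rate(s, dim_model=128) for s in (1, 100, 400, 4000)]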

Task and the dataset

The Transformer is a general framework for a variety of NLP tasks. This tutorial focuses on sequence-to-sequence learning: a typical case to illustrate how it works.

As for the dataset, there are two example tasks: copy and sort, together with two real-world translation tasks: the multi30k en-de task and the wmt14 en-de task. (A data-generation sketch for the two toy tasks follows the note below.)

  • copy dataset: copy the input sequences to the output. (train/valid/test: 9000, 1000, 1000)

  • sort dataset: sort the input sequences as the output. (train/valid/test: 9000, 1000, 1000)

  • Multi30k en-de, translate sentences from En to De. (train/valid/test: 29000, 1000, 1000)

  • WMT14 en-de, translate sentences from En to De. (train/valid/test: 4500966/3000/3003)

Note

Training with wmt14 requires multi-GPU support and is not available. Contributions are welcome!
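To make the two toy tasks concrete, here is a hedged sketch of how copy and sort samples could be generated. The actual get_dataset implementation may differ in sequence lengths, vocabulary handling and special tokens; everything here is illustrative.

import random

def make_toy_sample(task, vocab_size=30, min_len=5, max_len=10):
    "Generate one (source, target) pair for the 'copy' or 'sort' toy task."
    length = random.randint(min_len, max_len)
    src = [random.randrange(4, vocab_size) for _ in range(length)]  # low ids reserved for special tokens
    tgt = list(src) if task == 'copy' else sorted(src)
    return src, tgt

print(make_toy_sample('copy'))
print(make_toy_sample('sort'))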

Graph building

Batching This is similar to the way you handle Tree-LSTM. Build a graph pool in advance, including all possible combinations of input lengths and output lengths. Then, for each sample in a batch, call dgl.batch to batch the graphs of the corresponding sizes together into one large graph.

You can wrap the process of creating the graph pool and building a BatchedGraph in dataset.GraphPool and dataset.TranslationDataset.

graph_pool = GraphPool()

data_iter = dataset(graph_pool, mode='train', batch_size=1, devices=devices)
for graph in data_iter:
    print(graph.nids['enc']) # encoder node ids
    print(graph.nids['dec']) # decoder node ids
    print(graph.eids['ee']) # encoder-encoder edge ids
    print(graph.eids['ed']) # encoder-decoder edge ids
    print(graph.eids['dd']) # decoder-decoder edge ids
    print(graph.src[0]) # Input word index list
    print(graph.src[1]) # Input positions
    print(graph.tgt[0]) # Output word index list
    print(graph.tgt[1]) # Output positions
    break


Output:

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], device='cuda:0')
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80], device='cuda:0')
tensor([ 81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
         95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
        109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
        123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
        137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
        151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
        165, 166, 167, 168, 169, 170], device='cuda:0')
tensor([171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,
        185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
        199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,
        213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225],
       device='cuda:0')
tensor([28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 0, 28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')


Put it all together

Train a one-layer, 128-dimension, single-head Transformer on the copy task. Set the other parameters to their default values.

The inference module is not included in this tutorial. It requires beam search. For a full implementation, see the GitHub repository.

from tqdm import tqdm
import torch as th
import numpy as np

from loss import LabelSmoothing, SimpleLossCompute
from modules import make_model
from optims import NoamOpt
from dgl.contrib.transformer import get_dataset, GraphPool

def run_epoch(data_iter, model, loss_compute, is_train=True):
    for i, g in tqdm(enumerate(data_iter)):
        with th.set_grad_enabled(is_train):
            output = model(g)
            loss = loss_compute(output, g.tgt_y, g.n_tokens)
    print('average loss: {}'.format(loss_compute.avg_loss))
    print('accuracy: {}'.format(loss_compute.accuracy))

N = 1
batch_size = 128
devices = ['cuda' if th.cuda.is_available() else 'cpu']

dataset = get_dataset("copy")
V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 128

# Create model
model = make_model(V, V, N=N, dim_model=128, dim_ff=128, h=1)

# Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight

model, criterion = model.to(devices[0]), criterion.to(devices[0])
model_opt = NoamOpt(dim_model, 1, 400,
                    th.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
loss_compute = SimpleLossCompute

att_maps = []
for epoch in range(4):
    train_iter = dataset(graph_pool, mode='train', batch_size=batch_size, devices=devices)
    valid_iter = dataset(graph_pool, mode='valid', batch_size=batch_size, devices=devices)
    print('Epoch: {} Training...'.format(epoch))
    model.train(True)
    run_epoch(train_iter, model,
              loss_compute(criterion, model_opt), is_train=True)
    print('Epoch: {} Evaluating...'.format(epoch))
    model.att_weight_map = None
    model.eval()
    run_epoch(valid_iter, model,
              loss_compute(criterion, None), is_train=False)
    att_maps.append(model.att_weight_map)


Visualization

After training, you can visualize the attention that the Transformer produces on the copy task.

src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
# visualize head 0 of encoder-decoder attention
att_animation(att_maps, 'e2d', src_seq, tgt_seq, 0)


image5

From the figure you can see that the decoder nodes gradually learn to attend to the corresponding nodes in the input sequence, which is the expected behavior.

Multi-head attention

Besides the single-head attention trained on the toy tasks, we also visualize the attention scores of encoder self-attention, decoder self-attention, and encoder-decoder attention of a one-layer Transformer network trained on the multi30k dataset.

From the visualization you can see the diversity of the different heads, which is what you would expect. Different heads learn different relations between word pairs.

  • Encoder self-attention

    image6

  • Encoder-decoder attention Most words in the target sequence attend to their related words in the source sequence, for example: when generating "See" (in De), several heads attend to "lake"; when generating "Eisfischerhütte", several heads attend to "ice".

    image7

  • Decoder self-attention Most words attend to their previous few words.

    image8

Adaptive Universal Transformer

A recent research paper by Google, Universal Transformer, is an example showing how update_graph adapts to more complex updating rules.

The Universal Transformer was proposed to address the issue that the vanilla Transformer is not computationally universal, by introducing recurrence into the Transformer:

  • The basic idea of the Universal Transformer is to repeatedly revise the representations of all symbols in the sequence in each recurrent step, by applying a Transformer layer on the representations.

  • Compared to the vanilla Transformer, the Universal Transformer shares weights among its layers, and it does not fix the recurrence time (i.e., the number of layers in the Transformer).

A further optimization employs an Adaptive Computation Time (ACT) mechanism to allow the model to dynamically adjust the number of times the representation of each position in the sequence is revised (referred to as a step below). This model is also known as the Adaptive Universal Transformer (AUT).

In AUT, you maintain a list of active nodes. At each step t, we compute a halting probability h (0 < h < 1) for all nodes in this list by:

$$h_i^t = \sigma(W_h x_i^t + b_h)$$

and then dynamically decide which nodes are still active. A node halts at time T if and only if

$$\sum_{t=1}^{T-1} h^t < 1 - \varepsilon \leq \sum_{t=1}^{T} h^t$$

Halted nodes are removed from the list. The procedure continues until the list is empty or a pre-defined maximum step is reached. From the DGL perspective, this means that the "active" graph becomes sparser over time.

The final state of a node s_i is a weighted average of x_i^t by h_i^t:

$$s_i = \sum_{t=1}^{T} h_i^t \cdot x_i^t$$
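The halting rule can be sketched for a single node in plain Python, mirroring what halt_and_accum in the class below does per step (the threshold 0.99 plays the role of 1 - eps; all names and values here are illustrative).

def act_step(x_t, h_t, state, thres=0.99):
    """One ACT update for a single node.
    state holds 'sum_p' (accumulated halting prob.), 'r' (remainder), 's' (accumulated state), 'active'."""
    if not state['active']:
        return state
    sum_p = state['sum_p'] + h_t
    if sum_p < thres:                       # keep pondering: weight this step's state by h_t
        state['s'] += h_t * x_t
        state['sum_p'] = sum_p
        state['r'] = 1 - sum_p              # keep the remainder up to date (used if max depth is hit)
    else:                                   # halt: weight the last state by the remainder
        state['r'] = 1 - state['sum_p']
        state['s'] += state['r'] * x_t
        state['active'] = False
    return state

state = {'sum_p': 0.0, 'r': 1.0, 's': 0.0, 'active': True}
for x_t, h_t in [(1.0, 0.3), (2.0, 0.4), (3.0, 0.5)]:   # per-step states and halting probabilities
    state = act_step(x_t, h_t, state)
# halts at the third step: s = 0.3*1.0 + 0.4*2.0 + (1 - 0.7)*3.0 = 2.0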

In DGL, the algorithm is implemented by calling update_graph on the nodes that are still active and on the edges associated with those nodes. The following code shows the Universal Transformer class in DGL:

class UTransformer(nn.Module):
    "Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
    MAX_DEPTH = 8
    thres = 0.99
    act_loss_weight = 0.01
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, d_k):
        super(UTransformer, self).__init__()
        self.encoder,  self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc, self.time_enc = pos_enc, time_enc
        self.halt_enc = HaltingUnit(h * d_k)
        self.halt_dec = HaltingUnit(h * d_k)
        self.generator = generator
        self.h, self.d_k = h, d_k

    def step_forward(self, nodes):
        # add positional encoding and time encoding, increment step by one
        x = nodes.data['x']
        step = nodes.data['step']
        pos = nodes.data['pos']
        return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))),
                'step': step + 1}

    def halt_and_accum(self, name, end=False):
        "field: 'enc' or 'dec'"
        halt = self.halt_enc if name == 'enc' else self.halt_dec
        thres = self.thres
        def func(nodes):
            p = halt(nodes.data['x'])
            sum_p = nodes.data['sum_p'] + p
            active = (sum_p < thres) & (1 - end)
            _continue = active.float()
            r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
            s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
            return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}
        return func

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."
        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        N, E = graph.n_nodes, graph.n_edges
        nids, eids = graph.nids, graph.eids

        # embed & pos
        g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
        g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
        g.nodes[nids['enc']].data['pos'] = graph.src[1]
        g.nodes[nids['dec']].data['pos'] = graph.tgt[1]

        # init step
        device = next(self.parameters()).device
        g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device)    # accumulated state
        g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device)                    # halting prob
        g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device)                     # remainder
        g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device)                # sum of pondering values
        g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device)                  # step
        g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device)                # active

        for step in range(self.MAX_DEPTH):
            pre_func = self.encoder.pre_func('qkv')
            post_func = self.encoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])

        g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])

        for step in range(self.MAX_DEPTH):
            pre_func = self.decoder.pre_func('qkv')
            post_func = self.decoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes)])

            pre_q = self.decoder.pre_func('q', 1)
            pre_kv = self.decoder.pre_func('kv', 1)
            post_func = self.decoder.post_func(1)
            nodes_e = nids['enc']
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(pre_q, nodes), (pre_kv, nodes_e)],
                              [(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])

        g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
        act_loss = th.mean(g.ndata['r']) # ACT loss

        return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight


Call filter_nodes and filter_edges to find the nodes/edges that are still active:

Note

  • filter_nodes() takes a predicate and a node ID list/tensor as input, then returns a tensor of node IDs that satisfy the given predicate.

  • filter_edges() takes a predicate and an edge ID list/tensor as input, then returns a tensor of edge IDs that satisfy the given predicate.

For the complete implementation, see the GitHub repository.

The figure below shows the effect of Adaptive Computation Time. Different positions of a sentence are revised a different number of times.

image9

You can also visualize the dynamics of the step distribution over nodes during AUT training on the sort task (which reaches an accuracy of 99.7%), which demonstrates how AUT learns to reduce the number of recurrence steps during training.

image10

Note

The notebook itself is not executable because of its many dependencies. Download 7_transformer.py, copy the python script to the directory examples/pytorch/transformer, and then run python 7_transformer.py to see how it works.



Transformer as a Graph Neural Network

Author: Zihao Ye, Jinjing Zhou, Qipeng Guo, Quan Gan, Zheng Zhang

Warning

The tutorial aims at gaining insights into the paper, with code as a mean of explanation. The implementation thus is NOT optimized for running efficiency. For recommended implementation, please refer to the official examples.

In this tutorial, you learn about a simplified implementation of the Transformer model. You can see highlights of the most important design points. For instance, there is only single-head attention. The complete code can be found here.

The overall structure is similar to the one from the research papaer Annotated Transformer.

The Transformer model, as a replacement of CNN/RNN architecture for sequence modeling, was introduced in the research paper: Attention is All You Need. It improved the state of the art for machine translation as well as natural language inference task (GPT). Recent work on pre-training Transformer with large scale corpus (BERT) supports that it is capable of learning high-quality semantic representation.

The interesting part of Transformer is its extensive employment of attention. The classic use of attention comes from machine translation model, where the output token attends to all input tokens.

Transformer additionally applies self-attention in both decoder and encoder. This process forces words relate to each other to combine together, irrespective of their positions in the sequence. This is different from RNN-based model, where words (in the source sentence) are combined along the chain, which is thought to be too constrained.

Attention layer of Transformer

In the attention layer of Transformer, for each node the module learns to assign weights on its in-coming edges. For node pair (i,j)(�,�) (from i� to j�) with node xi,xj∈Rn��,��∈��, the score of their connection is defined as follows:

qj=Wq⋅xjki=Wk⋅xivi=Wv⋅xiscore=qTjki��=��⋅����=��⋅����=��⋅��score=�����

where Wq,Wk,Wv∈Rn×dk��,��,��∈��×�� map the representations x� to “query”, “key”, and “value” space respectively.

There are other possibilities to implement the score function. The dot product measures the similarity of a given query qj�� and a key ki��: if j� needs the information stored in i�, the query vector at position j� (qj��) is supposed to be close to key vector at position i� (ki��).

The score is then used to compute the sum of the incoming values, normalized over the weights of edges, stored in wvwv. Then apply an affine layer to wvwv to get the output o�:

wji=exp{scoreji}∑(k,i)∈Eexp{scoreki}wvi=∑(k,i)∈Ewkivko=Wo⋅wv���=exp⁡{score��}∑(�,�)∈�exp⁡{score��}wv�=∑(�,�)∈�������=��⋅wv

Multi-head attention layer

In Transformer, attention is multi-headed. A head is very much like a channel in a convolutional network. The multi-head attention consists of multiple attention heads, in which each head refers to a single attention module. wv(i)wv(�) for all the heads are concatenated and mapped to output o� with an affine layer:

o=Wo⋅concat([wv(0),wv(1),⋯,wv(h)])�=��⋅concat([wv(0),wv(1),⋯,wv(ℎ)])

The code below wraps necessary components for multi-head attention, and provides two interfaces.

  • get maps state ‘x’, to query, key and value, which is required by following steps(propagate_attention).

  • get_o maps the updated value after attention to the output o� for post-processing.

class MultiHeadAttention(nn.Module):
    "Multi-Head Attention"
    def __init__(self, h, dim_model):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model), 4)

    def get(self, x, fields='qkv'):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if 'k' in fields:
            ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if 'v' in fields:
            ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))

Copy to clipboard

How DGL implements Transformer with a graph neural network

You get a different perspective of Transformer by treating the attention as edges in a graph and adopt message passing on the edges to induce the appropriate processing.

Graph structure

Construct the graph by mapping tokens of the source and target sentence to nodes. The complete Transformer graph is made up of three subgraphs:

Source language graph. This is a complete graph, each token si�� can attend to any other token sj�� (including self-loops). 

image0

 Target language graph. The graph is half-complete, in that ti�� attends only to tj�� if i>j�>� (an output token can not depend on future words). 

image1

 Cross-language graph. This is a bi-partitie graph, where there is an edge from every source token si�� to every target token tj��, meaning every target token can attend on source tokens. 

image2

The full picture looks like this: 

image3

Pre-build the graphs in dataset preparation stage.

Message passing

Once you define the graph structure, move on to defining the computation for message passing.

Assuming that you have already computed all the queries qi��, keys ki�� and values vi��. For each node i� (no matter whether it is a source token or target token), you can decompose the attention computation into two steps:

  1. Message computation: Compute attention score scoreijscore�� between i� and all nodes j� to be attended over, by taking the scaled-dot product between qi�� and kj��. The message sent from j� to i� will consist of the score scoreijscore�� and the value vj��.

  2. Message aggregation: Aggregate the values vj�� from all j� according to the scores scoreijscore��.

Simple implementation
Message computation

Compute score and send source node’s v to destination’s mailbox

def message_func(edges):
    return {'score': ((edges.src['k'] * edges.dst['q'])
                      .sum(-1, keepdim=True)),
            'v': edges.src['v']}

Copy to clipboard

Message aggregation

Normalize over all in-edges and weighted sum to get output

import torch as th
import torch.nn.functional as F

def reduce_func(nodes, d_k=64):
    v = nodes.mailbox['v']
    att = F.softmax(nodes.mailbox['score'] / th.sqrt(d_k), 1)
    return {'dx': (att * v).sum(1)}

Copy to clipboard

Execute on specific edges
import functools.partial as partial
def naive_propagate_attention(self, g, eids):
    g.send_and_recv(eids, message_func, partial(reduce_func, d_k=self.d_k))

Copy to clipboard

Speeding up with built-in functions

To speed up the message passing process, use DGL’s built-in functions, including:

  • fn.src_mul_egdes(src_field, edges_field, out_field) multiplies source’s attribute and edges attribute, and send the result to the destination node’s mailbox keyed by out_field.

  • fn.copy_edge(edges_field, out_field) copies edge’s attribute to destination node’s mailbox.

  • fn.sum(edges_field, out_field) sums up edge’s attribute and sends aggregation to destination node’s mailbox.

Here, you assemble those built-in functions into propagate_attention, which is also the main graph operation function in the final implementation. To accelerate it, break the softmax operation into the following steps. Recall that for each head there are two phases.

  1. Compute attention score by multiply src node’s k and dst node’s q

    • g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)

  2. Scaled Softmax over all dst nodes’ in-coming edges

    • Step 1: Exponentialize score with scale normalize constant

      • g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))

        scoreij←exp(scoreijdk−−√)score��←exp⁡(score����)

    • Step 2: Get the “values” on associated nodes weighted by “scores” on in-coming edges of each node; get the sum of “scores” on in-coming edges of each node for normalization. Note that here wvwv is not normalized.

      • msg: fn.src_mul_edge('v', 'score', 'v'), reduce: fn.sum('v', 'wv')

        wvj=∑i=1Nscoreij⋅viwv�=∑�=1�score��⋅��

      • msg: fn.copy_edge('score', 'score'), reduce: fn.sum('score', 'z')

        zj=∑i=1Nscoreijz�=∑�=1�score��

The normalization of wvwv is left to post processing.

def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func


def propagate_attention(self, g, eids):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
    # Update node state
    g.send_and_recv(eids,
                    [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                    [fn.sum('v', 'wv'), fn.sum('score', 'z')])

Copy to clipboard

Preprocessing and postprocessing

In Transformer, data needs to be pre- and post-processed before and after the propagate_attention function.

Preprocessing The preprocessing function pre_func first normalizes the node representations and then map them to a set of queries, keys and values, using self-attention as an example:

x←LayerNorm(x)[q,k,v]←[Wq,Wk,Wv]⋅x�←LayerNorm(�)[�,�,�]←[��,��,��]⋅�

Postprocessing The postprocessing function post_func completes the whole computation corresponding to one layer of the Transformer:

  1. Normalize $\mathbf{wv}$ and get the output of the multi-head attention layer $o$:

    $\mathbf{wv} \leftarrow \frac{\mathbf{wv}}{z}, \qquad o \leftarrow W_o \cdot \mathbf{wv} + b_o$

    then add the residual connection:

    $x \leftarrow x + o$

  2. Apply a two-layer position-wise feed-forward network to $x$, then add the residual connection:

    $x \leftarrow x + \mathrm{FFN}(\mathrm{LayerNorm}(x))$

    where $\mathrm{FFN}$ refers to the feed-forward function (a sketch of the sublayer-connection and feed-forward helpers follows this list).
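
The Encoder and Decoder classes below refer to layer.sublayer and layer.feed_forward, which are not defined in this excerpt. For reference, here is a minimal sketch of these helpers following the conventions of The Annotated Transformer; the exact definitions in the full example may differ (for instance, it defines its own LayerNorm class):

import torch as th
import torch.nn as nn

class SublayerConnection(nn.Module):
    "Residual connection around a sub-layer, with pre-normalization and dropout."
    def __init__(self, size, dropout=0.1):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)    # stands in for the tutorial's own LayerNorm
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # x + dropout(sublayer(norm(x)))
        return x + self.dropout(sublayer(self.norm(x)))

class PositionwiseFeedForward(nn.Module):
    "Two-layer position-wise feed-forward network: W_2 relu(W_1 x + b_1) + b_2."
    def __init__(self, dim_model, dim_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(dim_model, dim_ff)
        self.w_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(th.relu(self.w_1(x))))

With helpers of this shape, the post_func bodies below read directly as the equations above.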

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func

    def post_func(self, i):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv', l=0):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            if fields == 'kv':
                norm_x = x # In enc-dec attention, x has already been normalized.
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)
        return func

    def post_func(self, i, l=0):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func


This completes all procedures of one layer of encoder and decoder in Transformer.

Note

The sublayer connection part is a little bit different from the original paper. However, this implementation is the same as The Annotated Transformer and OpenNMT.

Main class of Transformer graph

The processing flow of the Transformer can be seen as a two-stage message passing within the complete graph (with pre- and post-processing applied appropriately): 1) self-attention in the encoder, 2) self-attention in the decoder followed by cross-attention between the encoder and decoder, as shown below.

image4

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
        self.h, self.d_k = h, d_k

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        nids, eids = graph.nids, graph.eids

        # Word Embedding and Position Embedding
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
        g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
        g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)

        for i in range(self.encoder.N):
            # Step 1: Encoder Self-attention
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            nodes, edges = nids['enc'], eids['ee']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])

        for i in range(self.decoder.N):
            # Step 2: Decoder Self-attention
            pre_func = self.decoder.pre_func(i, 'qkv')
            post_func = self.decoder.post_func(i)
            nodes, edges = nids['dec'], eids['dd']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            # Step 3: Encoder-Decoder attention
            pre_q = self.decoder.pre_func(i, 'q', 1)
            pre_kv = self.decoder.pre_func(i, 'kv', 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']
            self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])

        return self.generator(g.ndata['x'][nids['dec']])


Note

By calling the update_graph function, you can create your own Transformer on any subgraph with nearly the same code. This flexibility enables us to discover new, sparse structures (cf. the local attention mentioned here; a toy sketch follows this note). Note that this implementation doesn’t use masks or padding, which makes the logic clearer and saves memory. The trade-off is that the implementation is slower.
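
As a concrete illustration of such a sparse structure (a purely hypothetical helper, not part of the tutorial’s dataset code), a local-attention variant of the decoder self-attention edges could be enumerated like this and used in place of the complete lower-triangular edge set:

def local_attention_edges(tgt_len, window=3):
    "Each target token j attends only to itself and the previous `window` tokens."
    return [(i, j) for j in range(tgt_len)
                   for i in range(max(0, j - window), j + 1)]

print(local_attention_edges(5, window=2))
# [(0, 0), (0, 1), (1, 1), (0, 2), (1, 2), (2, 2), (1, 3), (2, 3), (3, 3), (2, 4), (3, 4), (4, 4)]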

Training

This tutorial does not cover several other techniques mentioned in the original paper, such as Label Smoothing and the Noam optimization schedule. For a detailed description of these modules, read The Annotated Transformer written by the Harvard NLP team.

Task and the dataset

The Transformer is a general framework for a variety of NLP tasks. This tutorial focuses on sequence-to-sequence learning: a typical case that illustrates how it works.

As for the dataset, there are two example tasks, copy and sort, together with two real-world translation tasks: the Multi30k en-de task and the WMT14 en-de task.

  • copy dataset: copy input sequences to output. (train/valid/test: 9000, 1000, 1000)

  • sort dataset: sort input sequences as output. (train/valid/test: 9000, 1000, 1000)

  • Multi30k en-de, translate sentences from En to De. (train/valid/test: 29000, 1000, 1000)

  • WMT14 en-de, translate sentences from En to De. (Train/Valid/Test: 4500966/3000/3003)

Note

Training on WMT14 requires multi-GPU support and is not available in this example. Contributions are welcome!

Graph building

Batching This is similar to the way you handle Tree-LSTM. Build a graph pool in advance, including all possible combinations of input and output lengths. Then, for each sample in a batch, call dgl.batch to batch graphs of the corresponding sizes together into a single large graph (as sketched below).

You can wrap the process of creating the graph pool and building a BatchedGraph in dataset.GraphPool and dataset.TranslationDataset.
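
To make the graph structure concrete, here is a small sketch (a hypothetical helper, not part of dataset.GraphPool) that enumerates the three edge sets for a single (src_len, tgt_len) pair: a complete source graph with self-loops, a lower-triangular target graph, and a bipartite source-to-target graph. With src_len=9 and tgt_len=10, the counts match the edge-id tensors printed below:

def build_edge_sets(src_len, tgt_len):
    "Enumerate (u, v) pairs for the 'ee', 'dd' and 'ed' subgraphs."
    # Encoder nodes are numbered 0..src_len-1, decoder nodes src_len..src_len+tgt_len-1.
    enc = list(range(src_len))
    dec = list(range(src_len, src_len + tgt_len))
    ee = [(u, v) for u in enc for v in enc]                                        # complete, with self-loops
    dd = [(u, v) for i, u in enumerate(dec) for j, v in enumerate(dec) if i <= j]  # token j sees tokens i <= j
    ed = [(u, v) for u in enc for v in dec]                                        # every decoder token sees every encoder token
    return ee, dd, ed

ee, dd, ed = build_edge_sets(src_len=9, tgt_len=10)
print(len(ee), len(dd), len(ed))   # 81 55 90

In the actual dataset code, this enumeration is handled by the graph pool: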

graph_pool = GraphPool()

data_iter = dataset(graph_pool, mode='train', batch_size=1, devices=devices)
for graph in data_iter:
    print(graph.nids['enc']) # encoder node ids
    print(graph.nids['dec']) # decoder node ids
    print(graph.eids['ee']) # encoder-encoder edge ids
    print(graph.eids['ed']) # encoder-decoder edge ids
    print(graph.eids['dd']) # decoder-decoder edge ids
    print(graph.src[0]) # Input word index list
    print(graph.src[1]) # Input positions
    print(graph.tgt[0]) # Output word index list
    print(graph.tgt[1]) # Output positions
    break


Output:

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], device='cuda:0')
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80], device='cuda:0')
tensor([ 81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
         95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
        109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
        123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
        137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
        151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
        165, 166, 167, 168, 169, 170], device='cuda:0')
tensor([171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,
        185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
        199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,
        213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225],
       device='cuda:0')
tensor([28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 0, 28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')


Put it all together

Train a one-head Transformer with one layer and a model dimension of 128 on the copy task. Set the other parameters to their defaults.

The inference module is not included in this tutorial; it requires beam search. For a full implementation, see the GitHub repo.

from tqdm import tqdm
import torch as th
import numpy as np

from loss import LabelSmoothing, SimpleLossCompute
from modules import make_model
from optims import NoamOpt
from dgl.contrib.transformer import get_dataset, GraphPool

def run_epoch(data_iter, model, loss_compute, is_train=True):
    for i, g in tqdm(enumerate(data_iter)):
        with th.set_grad_enabled(is_train):
            output = model(g)
            loss = loss_compute(output, g.tgt_y, g.n_tokens)
    print('average loss: {}'.format(loss_compute.avg_loss))
    print('accuracy: {}'.format(loss_compute.accuracy))

N = 1
batch_size = 128
devices = ['cuda' if th.cuda.is_available() else 'cpu']

dataset = get_dataset("copy")
graph_pool = GraphPool()  # graph pool shared by the data iterators below
V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 128

# Create model
model = make_model(V, V, N=N, dim_model=128, dim_ff=128, h=1)

# Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight

model, criterion = model.to(devices[0]), criterion.to(devices[0])
model_opt = NoamOpt(dim_model, 1, 400,
                    th.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
loss_compute = SimpleLossCompute

att_maps = []
for epoch in range(4):
    train_iter = dataset(graph_pool, mode='train', batch_size=batch_size, devices=devices)
    valid_iter = dataset(graph_pool, mode='valid', batch_size=batch_size, devices=devices)
    print('Epoch: {} Training...'.format(epoch))
    model.train(True)
    run_epoch(train_iter, model,
              loss_compute(criterion, model_opt), is_train=True)
    print('Epoch: {} Evaluating...'.format(epoch))
    model.att_weight_map = None
    model.eval()
    run_epoch(valid_iter, model,
              loss_compute(criterion, None), is_train=False)
    att_maps.append(model.att_weight_map)


Visualization

After training, you can visualize the attention that the Transformer generates on copy task.

# VIZ_IDX: index of the validation sample to visualize (defined in the full example)
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
# visualize head 0 of encoder-decoder attention
att_animation(att_maps, 'e2d', src_seq, tgt_seq, 0)


image5

From the figure you can see that the decoder nodes gradually learn to attend to the corresponding nodes in the input sequence, which is the expected behavior.

Multi-head attention

Besides the one-head attention trained on the toy task, we also visualize the attention scores of the encoder’s self-attention, the decoder’s self-attention, and the encoder-decoder attention of a one-layer Transformer network trained on the Multi30k dataset.

From the visualization you can see the diversity of the different heads, which is what you would expect: different heads learn different relations between word pairs.

  • Encoder Self-Attention 

    image6

  • Encoder-Decoder Attention Most words in the target sequence attend to their related words in the source sequence. For example, when generating “See” (in De), several heads attend to “lake”; when generating “Eisfischerhütte”, several heads attend to “ice”. 

    image7

  • Decoder Self-Attention Most words attend to their previous few words. 

    image8

Adaptive Universal Transformer

A recent research paper by Google, Universal Transformer, is an example of how update_graph adapts to more complex updating rules.

The Universal Transformer was proposed to address the problem that the vanilla Transformer is not computationally universal, by introducing recurrence into the Transformer:

  • The basic idea of the Universal Transformer is to repeatedly revise its representations of all symbols in the sequence with each recurrent step, by applying a Transformer layer to the representations.

  • Compared to the vanilla Transformer, the Universal Transformer shares weights among its layers, and it does not fix the recurrence time (that is, the number of layers in the Transformer).

A further optimization employs an adaptive computation time (ACT) mechanism to allow the model to dynamically adjust the number of times the representation of each position in a sequence is revised (referred to as a step hereafter). This model is also known as the Adaptive Universal Transformer (AUT).

In AUT, you maintain a list of active nodes. In each step $t$, we compute a halting probability $h$ ($0 < h < 1$) for all nodes in this list by:

$h^t_i = \sigma(W_h x^t_i + b_h)$

then dynamically decide which nodes are still active. A node is halted at time $T$ if and only if $\sum_{t=1}^{T-1} h_t < 1 - \varepsilon \leq \sum_{t=1}^{T} h_t$. Halted nodes are removed from the list. The procedure proceeds until the list is empty or a pre-defined maximum step is reached. From DGL’s perspective, this means that the “active” graph becomes sparser over time.

The final state $s_i$ of a node is a weighted average of $x^t_i$ by $h^t_i$:

$s_i = \sum_{t=1}^{T} h^t_i \cdot x^t_i$
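
The halting probability above is produced by the HaltingUnit referenced in the code below, whose definition is not shown in this excerpt. A minimal sketch consistent with $h^t_i = \sigma(W_h x^t_i + b_h)$ (the exact implementation in the full example may differ) is:

import torch as th
import torch.nn as nn

class HaltingUnit(nn.Module):
    "Per-node halting probability h = sigmoid(W_h x + b_h)."
    def __init__(self, dim_model):
        super(HaltingUnit, self).__init__()
        self.linear = nn.Linear(dim_model, 1)

    def forward(self, x):
        return th.sigmoid(self.linear(x))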

In DGL, you implement the algorithm by calling update_graph on the nodes that are still active and on the edges associated with these nodes. The following code shows the Universal Transformer class in DGL:

class UTransformer(nn.Module):
    "Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
    MAX_DEPTH = 8
    thres = 0.99
    act_loss_weight = 0.01
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, d_k):
        super(UTransformer, self).__init__()
        self.encoder,  self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc, self.time_enc = pos_enc, time_enc
        self.halt_enc = HaltingUnit(h * d_k)
        self.halt_dec = HaltingUnit(h * d_k)
        self.generator = generator
        self.h, self.d_k = h, d_k

    def step_forward(self, nodes):
        # add positional encoding and time encoding, increment step by one
        x = nodes.data['x']
        step = nodes.data['step']
        pos = nodes.data['pos']
        return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))),
                'step': step + 1}

    def halt_and_accum(self, name, end=False):
        "field: 'enc' or 'dec'"
        halt = self.halt_enc if name == 'enc' else self.halt_dec
        thres = self.thres
        def func(nodes):
            p = halt(nodes.data['x'])
            sum_p = nodes.data['sum_p'] + p
            active = (sum_p < thres) & (1 - end)
            _continue = active.float()
            r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
            s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
            return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}
        return func

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."
        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        N, E = graph.n_nodes, graph.n_edges
        nids, eids = graph.nids, graph.eids

        # embed & pos
        g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
        g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
        g.nodes[nids['enc']].data['pos'] = graph.src[1]
        g.nodes[nids['dec']].data['pos'] = graph.tgt[1]

        # init step
        device = next(self.parameters()).device
        g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device)    # accumulated state
        g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device)                    # halting prob
        g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device)                     # remainder
        g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device)                # sum of pondering values
        g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device)                  # step
        g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device)                # active

        for step in range(self.MAX_DEPTH):
            pre_func = self.encoder.pre_func('qkv')
            post_func = self.encoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])

        g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])

        for step in range(self.MAX_DEPTH):
            pre_func = self.decoder.pre_func('qkv')
            post_func = self.decoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes)])

            pre_q = self.decoder.pre_func('q', 1)
            pre_kv = self.decoder.pre_func('kv', 1)
            post_func = self.decoder.post_func(1)
            nodes_e = nids['enc']
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(pre_q, nodes), (pre_kv, nodes_e)],
                              [(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])

        g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
        act_loss = th.mean(g.ndata['r']) # ACT loss

        return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight


Call filter_nodes and filter_edges to find the nodes/edges that are still active (a toy example follows the note below):

Note

  • filter_nodes() takes a predicate and a node ID list/tensor as input, then returns a tensor of node IDs that satisfy the given predicate.

  • filter_edges() takes a predicate and an edge ID list/tensor as input, then returns a tensor of edge IDs that satisfy the given predicate.
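
For instance, a minimal sketch with a hypothetical node feature, using the same DGL 0.x API as the rest of this tutorial:

import dgl
import torch as th

g = dgl.DGLGraph()
g.add_nodes(4)
g.ndata['active'] = th.tensor([[1], [0], [1], [0]], dtype=th.uint8)
# Returns the IDs of the nodes whose predicate evaluates to True: tensor([0, 2])
print(g.filter_nodes(lambda v: v.data['active'].view(-1)))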

For the full implementation, see the GitHub repo.

The figure below shows the effect of Adaptive Computation Time. Different positions of a sentence are revised a different number of times.

image9

You can also visualize the dynamics of the step distribution on nodes during the training of AUT on the sort task (which reaches 99.7% accuracy), which demonstrates how AUT learns to reduce the number of recurrence steps during training. 

image10

Note

The notebook itself is not executable due to its many dependencies. Download 7_transformer.py, copy the Python script to the directory examples/pytorch/transformer, and then run python 7_transformer.py to see how it works.

