当前位置:   article > 正文

【图神经网络】图注意力机制GAT以及Pytorch实现

图注意力机制

1 原理介绍

  GAT(Graph Attention Networks)图注意力网络的原理相对好理解,相比较GCN而言就是对汇聚到中心节点的邻居节点学习了一个权重,使其能够按照权重进行邻域特征的加和。下面列出的参考文献都给出了详细的原理介绍,这里只阐述重点。

1.1 计算注意力系数

α i , j = exp ⁡ ( α ( a T [ W x i ∥ W x j ] ) ) ∑ k ∈ N ( i ) ∪ i exp ⁡ ( α ( a T [ W x i ∥ W x j ] ) ) \alpha_{i, j}=\frac{\exp \left(\alpha\left(a^{T}\left[W x_{i} \| W x_{j}\right]\right)\right)}{\sum_{k \in N(i) \cup i} \exp \left(\alpha\left(a^{T}\left[W x_{i} \| Wx_{j}\right]\right)\right)} αi,j=kN(i)iexp(α(aT[WxiWxj]))exp(α(aT[WxiWxj]))
  其中的 α \alpha α 代表注意力分数, W W W代表可学习参数, x j x_j xj代表邻居节点的特征向量。
  解读一下这个公式:首先,一个共享参数 W W W 的线性映射对于顶点的特征进行了增维,当然这是一种常见的特征增强(feature augment)方法。 [ ⋅ ∣ ∣ ⋅ ] [\sdot||\sdot] [∣∣]表示对于顶点 i , j i, j i,j的变换后的特征进行了拼接。最后, α \alpha α把拼接后的高维特征映射到一个实数上。显然学习顶点 i , j i, j i,j之间的相关性,就是通过可学习的参数 W W W和映射 a a a完成的。有了相关系数,离注意力系数就差归一化了,其实就是用个 s o f t m a x softmax softmax,如上式所示。

1.2 加权求和

  第二步根据计算好的注意力系数,把特征加权求和一下。
h i ′ = σ ( ∑ j ∈ N i α i j W h j ) h_{i}^{\prime}=\sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j} W h_{j}\right) hi=σ jNiαijWhj
h i ′ h_{i}^{\prime} hi就是GAT输出的对于每个顶点 i i i的新特征(融合了邻域信息), σ ( ⋅ ) \sigma(\sdot) σ()是激活函数。

1.3 多头注意力机制

  multi-head attention也可以理解成用了ensemble的方法,因为衡量相似度的方法不同多用几个头(可以理解多用几种相似性度量方法)。
h i ′ ( K ) = ∥ k = 1 K σ ( ∑ j ∈ N i α i j k W k h j ) h_{i}^{\prime}(K)=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right){\tiny {\scriptsize } } hi(K)=k=1Kσ jNiαijkWkhj

2 代码实现

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 
    图注意力层
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features   # 节点表示向量的输入特征维度
        self.out_features = out_features   # 节点表示向量的输出特征维度
        self.dropout = dropout    # dropout参数
        self.alpha = alpha     # leakyrelu激活的参数
        self.concat = concat   # 如果为true, 再进行elu激活
        
        # 定义可训练参数,即论文中的W和a
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))  
        nn.init.xavier_uniform_(self.W.data, gain=1.414)  # xavier初始化
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)   # xavier初始化
        
        # 定义leakyrelu激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    
    def forward(self, inp, adj):
        """
        inp: input_fea [N, in_features]  in_features表示节点的输入特征向量元素个数
        adj: 图的邻接矩阵 维度[N, N] 非零即一,数据结构基本知识
        """
        h = torch.mm(inp, self.W)   # [N, out_features]
        N = h.size()[0]    # N 图的节点数
        
        a_input = torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)
        # [N, N, 2*out_features]
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        # [N, N, 1] => [N, N] 图注意力的相关系数(未归一化)
        
        zero_vec = -1e12 * torch.ones_like(e)    # 将没有连接的边置为负无穷
        attention = torch.where(adj>0, e, zero_vec)   # [N, N]
        # 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留,
        # 否则需要mask并置为非常小的值,原因是softmax的时候这个最小值会不考虑。
        attention = F.softmax(attention, dim=1)    # softmax形状保持不变 [N, N],得到归一化的注意力权重!
        attention = F.dropout(attention, self.dropout, training=self.training)   # dropout,防止过拟合
        h_prime = torch.matmul(attention, h)  # [N, N].[N, out_features] => [N, out_features]
        # 得到由周围节点通过注意力权重进行更新的表示
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime 
    
    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

加入多头注意力机制

class GAT(nn.Module):
    def __init__(self, n_feat, n_hid, n_class, dropout, alpha, n_heads):
        """Dense version of GAT
        n_heads 表示有几个GAL层,最后进行拼接在一起,类似self-attention
        从不同的子空间进行抽取特征。
        """
        super(GAT, self).__init__()
        self.dropout = dropout 
        
        # 定义multi-head的图注意力层
        self.attentions = [GraphAttentionLayer(n_feat, n_hid, dropout=dropout, alpha=alpha, concat=True) for _ in range(n_heads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)   # 加入pytorch的Module模块
        # 输出层,也通过图注意力层来实现,可实现分类、预测等功能
        self.out_att = GraphAttentionLayer(n_hid * n_heads, n_class, dropout=dropout,alpha=alpha, concat=False)
    
    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)   # dropout,防止过拟合
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)  # 将每个head得到的表示进行拼接
        x = F.dropout(x, self.dropout, training=self.training)   # dropout,防止过拟合
        x = F.elu(self.out_att(x, adj))   # 输出并激活
        return F.log_softmax(x, dim=1)  # log_softmax速度变快,保持数值稳定
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

3 深入理解GAT

来自参考文献【9】

3.1 为什么GAT适用于inductive任务

  GAT中重要的学习参数是 W W W a ( ⋅ ) a(\sdot) a(),因为上述的逐顶点运算方式,这两个参数仅与顶点特征相关与图的结构毫无关系(也可以从代码上看出来)。所以测试任务中改变图的结构,对于GAT影响并不大只需要改变 N i N_i Ni重新计算即可。与此相反的是,GCN是一种全图的计算方式,一次计算就更新全图的节点特征。学习的参数很大程度与图结构相关,这使得GCN在inductive任务上遇到困境。所以对于图结构存在噪声的任务来讲GAT会比GNN好。

3.2 与GCN的联系

  可以发现本质上而言:GCN与GAT都是将邻居顶点的特征聚合到中心顶点上(一种aggregate运算),利用graph上的local stationary学习新的顶点特征表达。不同的是GCN利用了拉普拉斯矩阵,GAT利用attention系数。一定程度上而言,GAT会更强,因为 顶点特征之间的相关性被更好地融入到模型中。

4 参考文献

[1]【GNN】GAT:Attention 在 GNN 中的应用
[2]图注意力网络(GAT) ICLR2018, Graph Attention Network论文详解
[3]https://github.com/dmlc/dgl/tree/master/examples/pytorch/gat
[4]https://github.com/PetarV-/GAT
[5]【图表示学习】pytorch实现图注意力网络GAT
[6]【图结构】之图注意力网络GAT详解
[7]Graph Attention Networks (GAT)pytorch源码解读
[8]Pytorch实现GAT(基于Message Passing消息传递机制实现)
[9]向往的GAT(图注意力网络的原理、实现及计算复杂度)
[10]通过pytorch深入理解图注意力网络(GAT)
[11]Pytorch实现GAT(基于PyTorch实现)
[12]GAT图注意力网络论文源码pytorch版超详细注释讲解!!!
[13]代码推荐这个

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号