当前位置:   article > 正文

门控图神经网络(GGNN)及代码分析

ggnn

门控图神经网络GGNN及代码分析

基本概念

GGNN是一种基于GRU的经典的空间域message passing的模型

问题描述

一个图 G = (V, E), 节点v ∈ V中存储D维向量,边e ∈ E中存储D × D维矩阵, 目的是构建网络GGNN。
实现每一次参数更新时,每个节点既接受相邻节点的信息,又向相邻节点发送信息。

主要贡献

基于GRU提出了GGNN,利用RNN类似原理实现了信息在graph中的传递。
在这里插入图片描述

传播模型

在这里插入图片描述
在这里插入图片描述

输出模型

在这里插入图片描述

来源论文
Gated Graph Sequence Neural Networks,ICLR 2016
链接:https://arxiv.org/abs/1511.05493
官方实现(Lua):https://github.com/yujiali/ggnn
第三方实现(pytorch):https://github.com/calebmah/ggnn.pytorch
GRU概念详见:https://blog.csdn.net/lthirdonel/article/details/88945257

代码分析

只看论文确实让人觉得玄学,特别是annotation部分,很迷
结合代码来看就好很多,这里例举的是@JamesChuanggg的pytorch实现ggnn.pytorch,这个实现的代码相比于官方版本来说,容易读很多

1.annotation
annotation = np.zeros([n_nodes, n_annotation_dim])
annotation[target[1]-1][0] = 1	
  • 1
  • 2

核心实现就是上面这个,除了表达到达关系部分用了1,其他padding成了0

2.每一个时间步的实现
class Propogator(nn.Module):
    """
    Gated Propogator for GGNN
    Using LSTM gating mechanism
    """
    def __init__(self, state_dim, n_node, n_edge_types):
        ## 初始化参照源代码
    def forward(self, state_in, state_out, state_cur, A):
        # 入边向量和出边向量
        A_in = A[:, :, :self.n_node*self.n_edge_types]
        A_out = A[:, :, self.n_node*self.n_edge_types:]

        # 入边向量和出边向量分别和图做计算
        a_in = torch.bmm(A_in, state_in)
        a_out = torch.bmm(A_out, state_out)
        a = torch.cat((a_in, a_out, state_cur), 2)

        # 类GRU部分
        r = self.reset_gate(a)
        z = self.update_gate(a)
        joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
        h_hat = self.tansform(joined_input)

        output = (1 - z) * state_cur + z * h_hat

        return output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
3.网络结构
class GGNN(nn.Module):
    """
    Gated Graph Sequence Neural Networks (GGNN)
    Mode: SelectNode
    Implementation based on https://arxiv.org/abs/1511.05493
    """
    def __init__(self, opt):
        # 初始化参考源代码
    def forward(self, prop_state, annotation, A):
        # prop_state:论文中的h
        # annotation:节点标注
        # A:图
        for i_step in range(self.n_steps):
            # 对于每一个时间步循环
            in_states = []
            out_states = []
            for i in range(self.n_edge_types):
                # 对输入特征做两个分支的全连接,得到入边特征,和出边特征
                # 每一种边都要计算一次
                in_states.append(self.in_fcs[i](prop_state))
                out_states.append(self.out_fcs[i](prop_state))
            # 将所有种类的边得到的特征连接起来
            in_states = torch.stack(in_states).transpose(0, 1).contiguous()
            in_states = in_states.view(-1, self.n_node*self.n_edge_types, self.state_dim)
            out_states = torch.stack(out_states).transpose(0, 1).contiguous()
            out_states = out_states.view(-1, self.n_node*self.n_edge_types, self.state_dim)

            # 用门控图模块更新h
            prop_state = self.propogator(in_states, out_states, prop_state, A)

        join_state = torch.cat((prop_state, annotation), 2)

        output = self.out(join_state)
        output = output.sum(2)

        return output

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/运维做开发/article/detail/750277
推荐阅读
  

闽ICP备14008679号