赞
踩
GGNN是一种基于GRU的经典的空间域message passing的模型
一个图 G = (V, E), 节点v ∈ V中存储D维向量,边e ∈ E中存储D × D维矩阵, 目的是构建网络GGNN。
实现每一次参数更新时,每个节点既接受相邻节点的信息,又向相邻节点发送信息。
基于GRU提出了GGNN,利用RNN类似原理实现了信息在graph中的传递。
来源论文
Gated Graph Sequence Neural Networks,ICLR 2016
链接:https://arxiv.org/abs/1511.05493
官方实现(Lua):https://github.com/yujiali/ggnn
第三方实现(pytorch):https://github.com/calebmah/ggnn.pytorch
GRU概念详见:https://blog.csdn.net/lthirdonel/article/details/88945257
只看论文确实让人觉得玄学,特别是annotation部分,很迷
结合代码来看就好很多,这里例举的是@JamesChuanggg的pytorch实现ggnn.pytorch,这个实现的代码相比于官方版本来说,容易读很多
annotation = np.zeros([n_nodes, n_annotation_dim])
annotation[target[1]-1][0] = 1
核心实现就是上面这个,除了表达到达关系部分用了1,其他padding成了0
class Propogator(nn.Module): """ Gated Propogator for GGNN Using LSTM gating mechanism """ def __init__(self, state_dim, n_node, n_edge_types): ## 初始化参照源代码 def forward(self, state_in, state_out, state_cur, A): # 入边向量和出边向量 A_in = A[:, :, :self.n_node*self.n_edge_types] A_out = A[:, :, self.n_node*self.n_edge_types:] # 入边向量和出边向量分别和图做计算 a_in = torch.bmm(A_in, state_in) a_out = torch.bmm(A_out, state_out) a = torch.cat((a_in, a_out, state_cur), 2) # 类GRU部分 r = self.reset_gate(a) z = self.update_gate(a) joined_input = torch.cat((a_in, a_out, r * state_cur), 2) h_hat = self.tansform(joined_input) output = (1 - z) * state_cur + z * h_hat return output
class GGNN(nn.Module): """ Gated Graph Sequence Neural Networks (GGNN) Mode: SelectNode Implementation based on https://arxiv.org/abs/1511.05493 """ def __init__(self, opt): # 初始化参考源代码 def forward(self, prop_state, annotation, A): # prop_state:论文中的h # annotation:节点标注 # A:图 for i_step in range(self.n_steps): # 对于每一个时间步循环 in_states = [] out_states = [] for i in range(self.n_edge_types): # 对输入特征做两个分支的全连接,得到入边特征,和出边特征 # 每一种边都要计算一次 in_states.append(self.in_fcs[i](prop_state)) out_states.append(self.out_fcs[i](prop_state)) # 将所有种类的边得到的特征连接起来 in_states = torch.stack(in_states).transpose(0, 1).contiguous() in_states = in_states.view(-1, self.n_node*self.n_edge_types, self.state_dim) out_states = torch.stack(out_states).transpose(0, 1).contiguous() out_states = out_states.view(-1, self.n_node*self.n_edge_types, self.state_dim) # 用门控图模块更新h prop_state = self.propogator(in_states, out_states, prop_state, A) join_state = torch.cat((prop_state, annotation), 2) output = self.out(join_state) output = output.sum(2) return output
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。