赞
踩
更多图神经网络和深度学习内容请关注:
DGL NN模块
是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架, DGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端, 它应该继承 PyTorch的NN模块;对于MXNet后端,它应该继承 MXNet Gluon的NN块; 对于TensorFlow后端,它应该继承 Tensorflow的Keras层。 在DGL NN模块中,构造函数中的参数注册和前向传播函数中使用的张量操作与后端框架一样。这种方式使得DGL的代码可以无缝嵌入到后端框架的代码中。 DGL和这些深度神经网络框架的主要差异是其独有的消息传递操作。
DGL已经集成了很多常用的 Conv Layers、 Dense Conv Layers、 Global Pooling Layers 和 Utility Modules。欢迎给DGL贡献更多的模块!
本章将使用PyTorch作为后端,用 SAGEConv
作为例子来介绍如何构建用户自己的DGL NN模块。
构造函数__init__
完成以下几个任务:
import torch.nn as nn from dgl.utils import expand_as_pair class SAGEConv(nn.Module): def __init__(self, in_feats, out_feats, aggregator_type, bias=True, norm=None, activation=None): super(SAGEConv, self).__init__() self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)#函数可以返回一个二维元组。 self._out_feats = out_feats self._aggre_type = aggregator_type self.norm = norm self.activation = activation
Using backend: pytorch
在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。
除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type
)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型包括 mean
、 sum
、 max
和 min
。一些模块可能会使用更加复杂的聚合函数,比如 lstm
。
上面代码里的 norm
是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:
h
v
=
h
v
/
∥
h
v
∥
2
h_v = h_v / \lVert h_v \rVert_2
hv=hv/∥hv∥2
# 聚合类型:mean、max_pool、lstm、gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'max_pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear
、 nn.LSTM
等。 构造函数的最后调用了 reset_parameters()
进行权重初始化。
def reset_parameters(self):
"""重新初始化可学习的参数"""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'max_pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
完整代码
import torch.nn as nn from dgl.utils import expand_as_pair class SAGEConv(nn.Module): def __init__(self, in_feats, out_feats, aggregator_type, bias=True, norm=None, activation=None): super(SAGEConv, self).__init__() self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._out_feats = out_feats self._aggre_type = aggregator_type self.norm = norm self.activation = activation # 聚合类型:mean、max_pool、lstm、gcn if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']: raise KeyError('Aggregator type {} not supported.'.format(aggregator_type)) if aggregator_type == 'max_pool': self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats) if aggregator_type == 'lstm': self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True) if aggregator_type in ['mean', 'max_pool', 'lstm']: self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias) self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias) self.reset_parameters() def reset_parameters(self): """重新初始化可学习的参数""" gain = nn.init.calculate_gain('relu') if self._aggre_type == 'max_pool': nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain) if self._aggre_type == 'lstm': self.lstm.reset_parameters() if self._aggre_type != 'gcn': nn.init.xavier_uniform_(self.fc_self.weight, gain=gain) nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
在NN模块中, forward()
函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比, DGL NN模块额外增加了1个参数 dgl.DGLGraph
。forward()
函数的内容一般可以分为3项操作:
下文展示了SAGEConv示例中的 forward()
函数。
def forward(self, graph, feat):
with graph.local_scope():
# 指定图类型,然后根据图类型扩展输入特征
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.local_scope()
:限定语句块内为局部作用域,对数据特征的操作不影响原始图的特征,常用于forward方法中,
用法:
def foo(g):
with g.local_scope():
g.edata['h'] = torch.ones((g.num_edges(), 3))
g.edata['h2'] = torch.ones((g.num_edges(), 3))
return g.edata['h']
in-place
操作会影响原始图数据
如:
def foo(g):
with g.local_scope():
# in-place operation
g.edata['h'] += 1
return g.edata['h']
参考dgl API
forward()
函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 比如在 GraphConv
等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入度为0时, mailbox
将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是,在 SAGEConv
模块中,被聚合的特征将会与节点的初始特征拼接起来, forward()
函数的输出不会全为0。在这种情况下,无需进行此类检验。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)和子图块(第6章:在大图上的随机(批次)训练)。
SAGEConv的数学公式如下:
h N ( d s t ) ( l + 1 ) = a g g r e g a t e ( { h s r c l , ∀ s r c ∈ N ( d s t ) } ) h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate} \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right) hN(dst)(l+1)=aggregate({hsrcl,∀src∈N(dst)})
h d s t ( l + 1 ) = σ ( W ⋅ c o n c a t ( h d s t l , h N ( d s t ) l + 1 ) + b ) h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat} (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1}) + b \right) hdst(l+1)=σ(W⋅concat(hdstl,hN(dst)l+1)+b)
h d s t ( l + 1 ) = n o r m ( h d s t l + 1 ) h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{l+1}) hdst(l+1)=norm(hdstl+1)
源节点特征 feat_src
和目标节点特征 feat_dst
需要根据图类型被指定。 用于指定图类型并将 feat
扩展为 feat_src
和 feat_dst
的函数是 expand_as_pair()
。 该函数的细节如下所示。
def expand_as_pair(input_, g=None): if isinstance(input_, tuple): # 二分图的情况 return input_ elif g is not None and g.is_block: # 子图块的情况 if isinstance(input_, Mapping): input_dst = { k: F.narrow_row(v, 0, g.number_of_dst_nodes(k)) for k, v in input_.items()} else: input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes()) return input_, input_dst else: # 同构图的情况 return input_, input_
对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。
在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)
。 当输入特征 feat
是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。
在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block
)。 在区块创建的阶段,dst nodes
位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()]
可以找到 feat_dst
。
确定 feat_src
和 feat_dst
之后,以上3种图类型的计算方法是相同的。
def forward(self, graph, feat): with graph.local_scope(): import dgl.function as fn import torch.nn.functional as F from dgl.utils import check_eq_shape # 指定图类型,然后根据图类型扩展输入特征 feat_src, feat_dst = expand_as_pair(feat, graph) if self._aggre_type == 'mean': graph.srcdata['h'] = feat_src graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) h_neigh = graph.dstdata['neigh'] elif self._aggre_type == 'gcn': check_eq_shape(feat) graph.srcdata['h'] = feat_src graph.dstdata['h'] = feat_dst graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) # 除以入度 degs = graph.in_degrees().to(feat_dst) h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1) elif self._aggre_type == 'max_pool': graph.srcdata['h'] = F.relu(self.fc_pool(feat_src)) graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) h_neigh = graph.dstdata['neigh'] else: raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) # GraphSAGE中gcn聚合不需要fc_self if self._aggre_type == 'gcn': rst = self.fc_neigh(h_neigh) else: rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。
# 激活函数
if self.activation is not None:
rst = self.activation(rst)
# 归一化
if self.norm is not None:
rst = self.norm(rst)
return rst
forward()
函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
完整代码
import torch.nn as nn from dgl.utils import expand_as_pair class SAGEConv(nn.Module): def __init__(self, in_feats, out_feats, aggregator_type, bias=True, norm=None, activation=None): super(SAGEConv, self).__init__() self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._out_feats = out_feats self._aggre_type = aggregator_type self.norm = norm self.activation = activation # 聚合类型:mean、max_pool、lstm、gcn if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']: raise KeyError('Aggregator type {} not supported.'.format(aggregator_type)) if aggregator_type == 'max_pool': self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats) if aggregator_type == 'lstm': self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True) if aggregator_type in ['mean', 'max_pool', 'lstm']: self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias) self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias) self.reset_parameters() def reset_parameters(self): """重新初始化可学习的参数""" gain = nn.init.calculate_gain('relu') if self._aggre_type == 'max_pool': nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain) if self._aggre_type == 'lstm': self.lstm.reset_parameters() if self._aggre_type != 'gcn': nn.init.xavier_uniform_(self.fc_self.weight, gain=gain) nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain) def forward(self, graph, feat): with graph.local_scope(): import dgl.function as fn import torch.nn.functional as F from dgl.utils import check_eq_shape # 指定图类型,然后根据图类型扩展输入特征 feat_src, feat_dst = expand_as_pair(feat, graph) if self._aggre_type == 'mean': graph.srcdata['h'] = feat_src graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) h_neigh = graph.dstdata['neigh'] elif self._aggre_type == 'gcn': check_eq_shape(feat) graph.srcdata['h'] = feat_src graph.dstdata['h'] = feat_dst graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) # 除以入度 degs = graph.in_degrees().to(feat_dst) h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1) elif self._aggre_type == 'max_pool': graph.srcdata['h'] = F.relu(self.fc_pool(feat_src)) graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) h_neigh = graph.dstdata['neigh'] else: raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) # GraphSAGE中gcn聚合不需要fc_self if self._aggre_type == 'gcn': rst = self.fc_neigh(h_neigh) else: rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) # 激活函数 if self.activation is not None: rst = self.activation(rst) # 归一化 if self.norm is not None: rst = self.norm(rst) return rst
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。