当前位置:   article > 正文

图神经网络框架DGL教程-第3章:构建图神经网络(GNN)模块_dgl keras

dgl keras

更多图神经网络和深度学习内容请关注:
在这里插入图片描述

第3章:构建图神经网络(GNN)模块

DGL NN模块是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架, DGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端, 它应该继承 PyTorch的NN模块;对于MXNet后端,它应该继承 MXNet Gluon的NN块; 对于TensorFlow后端,它应该继承 Tensorflow的Keras层。 在DGL NN模块中,构造函数中的参数注册和前向传播函数中使用的张量操作与后端框架一样。这种方式使得DGL的代码可以无缝嵌入到后端框架的代码中。 DGL和这些深度神经网络框架的主要差异是其独有的消息传递操作。

DGL已经集成了很多常用的 Conv LayersDense Conv LayersGlobal Pooling LayersUtility Modules。欢迎给DGL贡献更多的模块!

本章将使用PyTorch作为后端,用 SAGEConv 作为例子来介绍如何构建用户自己的DGL NN模块。

3.1 DGL NN模块的构造函数

构造函数__init__完成以下几个任务:

  • 设置选项。
  • 注册可学习的参数或者子模块。
  • 初始化参数。
import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)#函数可以返回一个二维元组。
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
Using backend: pytorch
  • 1

在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 对于图神经网络输入维度可被分为源节点特征维度和目标节点特征维度

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型包括 meansummaxmin。一些模块可能会使用更加复杂的聚合函数,比如 lstm

上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: h v = h v / ∥ h v ∥ 2 h_v = h_v / \lVert h_v \rVert_2 hv=hv/hv2

# 聚合类型:mean、max_pool、lstm、gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
    raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'max_pool':
    self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
    self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
    self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linearnn.LSTM 等。 构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):
    """重新初始化可学习的参数"""
    gain = nn.init.calculate_gain('relu')
    if self._aggre_type == 'max_pool':
        nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
    if self._aggre_type == 'lstm':
        self.lstm.reset_parameters()
    if self._aggre_type != 'gcn':
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
    nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
'
运行

完整代码

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        
        # 聚合类型:mean、max_pool、lstm、gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

3.2 编写DGL NN模块的forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比, DGL NN模块额外增加了1个参数 dgl.DGLGraphforward() 函数的内容一般可以分为3项操作:

  • 检测输入图对象是否符合规范。
  • 消息传递和聚合。
  • 聚合后,更新特征作为输出。

下文展示了SAGEConv示例中的 forward() 函数。

输入图对象的规范检测

def forward(self, graph, feat):
    with graph.local_scope():
        # 指定图类型,然后根据图类型扩展输入特征
        feat_src, feat_dst = expand_as_pair(feat, graph)
  • 1
  • 2
  • 3
  • 4
'
运行

graph.local_scope():限定语句块内为局部作用域,对数据特征的操作不影响原始图的特征,常用于forward方法中,
用法:

def foo(g):
    with g.local_scope():
        g.edata['h'] = torch.ones((g.num_edges(), 3))
        g.edata['h2'] = torch.ones((g.num_edges(), 3))
        return g.edata['h']
  • 1
  • 2
  • 3
  • 4
  • 5
'
运行

in-place操作会影响原始图数据
如:

def foo(g):
    with g.local_scope():
        # in-place operation
        g.edata['h'] += 1
        return g.edata['h']
  • 1
  • 2
  • 3
  • 4
  • 5
'
运行

参考dgl API

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输出不会全为0。在这种情况下,无需进行此类检验。

DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)和子图块(第6章:在大图上的随机(批次)训练)。

SAGEConv的数学公式如下:

h N ( d s t ) ( l + 1 ) = a g g r e g a t e ( { h s r c l , ∀ s r c ∈ N ( d s t ) } ) h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate} \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right) hN(dst)(l+1)=aggregate({hsrcl,srcN(dst)})

h d s t ( l + 1 ) = σ ( W ⋅ c o n c a t ( h d s t l , h N ( d s t ) l + 1 ) + b ) h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat} (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1}) + b \right) hdst(l+1)=σ(Wconcat(hdstl,hN(dst)l+1)+b)

h d s t ( l + 1 ) = n o r m ( h d s t l + 1 ) h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{l+1}) hdst(l+1)=norm(hdstl+1)

源节点特征 feat_src 和目标节点特征 feat_dst 需要根据图类型被指定。 用于指定图类型并将 feat 扩展为 feat_srcfeat_dst 的函数是 expand_as_pair()。 该函数的细节如下所示。

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # 二分图的情况
        return input_
    elif g is not None and g.is_block:
        # 子图块的情况
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # 同构图的情况
        return input_, input_
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
'
运行

对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。

在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)。 当输入特征 feat 是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。

在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block)。 在区块创建的阶段,dst nodes 位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()] 可以找到 feat_dst

确定 feat_srcfeat_dst 之后,以上3种图类型的计算方法是相同的。

消息传递和聚合

def forward(self, graph, feat):
    with graph.local_scope():
        import dgl.function as fn
        import torch.nn.functional as F
        from dgl.utils import check_eq_shape
        
        # 指定图类型,然后根据图类型扩展输入特征
        feat_src, feat_dst = expand_as_pair(feat, graph)
        
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # 除以入度
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'max_pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE中gcn聚合不需要fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
'
运行

上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。

聚合后,更新特征作为输出

# 激活函数
if self.activation is not None:
    rst = self.activation(rst)
# 归一化
if self.norm is not None:
    rst = self.norm(rst)
return rst
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

完整代码

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        
        # 聚合类型:mean、max_pool、lstm、gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
        
    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
        
	def forward(self, graph, feat):
	    with graph.local_scope():
	        import dgl.function as fn
	        import torch.nn.functional as F
	        from dgl.utils import check_eq_shape
	        
	        # 指定图类型,然后根据图类型扩展输入特征
	        feat_src, feat_dst = expand_as_pair(feat, graph)
	        
	        if self._aggre_type == 'mean':
	            graph.srcdata['h'] = feat_src
	            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
	            h_neigh = graph.dstdata['neigh']
	        elif self._aggre_type == 'gcn':
	            check_eq_shape(feat)
	            graph.srcdata['h'] = feat_src
	            graph.dstdata['h'] = feat_dst
	            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
	            # 除以入度
	            degs = graph.in_degrees().to(feat_dst)
	            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
	        elif self._aggre_type == 'max_pool':
	            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
	            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
	            h_neigh = graph.dstdata['neigh']
	        else:
	            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
	
	        # GraphSAGE中gcn聚合不需要fc_self
	        if self._aggre_type == 'gcn':
	            rst = self.fc_neigh(h_neigh)
	        else:
	            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
	            
	        # 激活函数
	        if self.activation is not None:
	            rst = self.activation(rst)
	        # 归一化
	        if self.norm is not None:
	            rst = self.norm(rst)
	        return rst
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84

3.3 异构图上的GraphConv模块

异构图上的GraphConv模块

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
  

闽ICP备14008679号