图卷积网络(Graph Convolutional Network,GCN)是近年来逐渐流行的一种神经网络结构,它能够处理具有一般拓扑图结构的数据,并深入挖掘其特征和规律。图卷积网络的关键是学习一个函数,该函数通过聚合节点自身的特征及其邻居节点的特征来生成节点的表示。图卷积网络的应用非常广泛,包括但不限于节点分类、链接预测、图分类,以及推荐系统、交通流量预测等领域。
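图卷积神经网络可以被划分为谱域图卷积和空域图卷积。在本文中,我们主要介绍空域图卷积。空域图卷积将卷积操作直接定义在每个节点的连接关系上,通过节点的连接关系来实现节点间信息的传播与更新。通常,一个图可以被表征为 $G=(X, A)$,其中 $X \in \mathbb{R}^{N \times F}$ 表示节点特征集($N$ 为节点数,$F$ 为每个节点的特征数),$A \in \mathbb{R}^{N \times N}$ 表示带自环的邻接矩阵,$D$ 为 $A$ 的度矩阵($D_{ii} = \sum_j A_{ij}$)。一次空域图卷积运算可以表示为:
$$X' = \sigma\left(D^{-1/2} A D^{-1/2} X W\right)$$
其中,$X'$ 为经过一次图卷积运算后的节点特征集,$\sigma$ 为激活函数,$W$ 为可学习参数。值得注意的是,每次图卷积运算均会聚合节点 1-hop 邻居的特征,通过叠加 $n$ 次图卷积运算可以获取 $n$-hop 邻居节点的信息。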
- import torch
- import torch.nn as nn
- from torch.nn.parameter import Parameter
- import torch.nn.init as init
- import math
- import numpy as np
- import torch.nn.functional as F
class GraphConvolution(nn.Module):
    """Spatial-domain graph convolution layer.

    Computes ``H' = D^(-1/2) A D^(-1/2) H W (+ b)``, where ``A`` is the
    adjacency matrix (expected to already contain self-loops) and ``D`` is
    the diagonal matrix of its row sums.
    """

    def __init__(self, input_dim: int, output_dim: int, bias: bool = True, device=None, dtype=None) -> None:
        """
        :param input_dim: number of input features per node
        :param output_dim: number of output features per node
        :param bias: whether to add a learnable bias to the output
        :param device: device for the parameters (forwarded to ``torch.empty``)
        :param dtype: dtype for the parameters (forwarded to ``torch.empty``)
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = bias
        # Stored as (in, out) -- the transpose of nn.Linear's layout -- so that
        # forward can compute ``support @ weight`` directly.
        self.weight = Parameter(torch.empty((input_dim, output_dim), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(output_dim, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise weight and bias from U(-1/sqrt(input_dim), 1/sqrt(input_dim)).

        This matches nn.Linear's default initialisation. ``kaiming_uniform_``
        cannot be reused as-is: it reads fan-in from dim 1, which for this
        (in, out)-shaped weight is ``output_dim`` rather than ``input_dim``.
        """
        fan_in = self.weight.size(0)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input_feature, adjacency):
        """Apply one graph convolution.

        :param input_feature: node features, [batch, nodes, input_dim]
        :param adjacency: adjacency matrix with self-loops, either shared
            across the batch as [nodes, nodes] or per sample as
            [batch, nodes, nodes]
        :return: updated node features, [batch, nodes, output_dim]

        process:
            H(l+1) = D^(-1/2) * A * D^(-1/2) * H(l) * W (+ b)
        """
        # Align dtype with the features so the matmul does not fail on a dtype mismatch.
        normalized = self.standard_adjacency(adjacency).to(input_feature.dtype)
        # matmul broadcasts leading dims, so a single [nodes, nodes] matrix is
        # applied to every sample; for 3-D inputs it equals einsum("bij,bjk->bik").
        support = torch.matmul(normalized, input_feature)
        output = torch.matmul(support, self.weight)
        if self.use_bias:
            output = output + self.bias
        return output

    def standard_adjacency(self, adjacency):
        """Symmetrically normalise an adjacency matrix: D^(-1/2) A D^(-1/2).

        The adjacency should already contain self-loops. Nodes with zero
        degree get zero rows/columns instead of inf/NaN.

        :param adjacency: [..., nodes, nodes], e.g. [nodes, nodes] or
            [batch, nodes, nodes]; integer/bool inputs are converted to the
            default float dtype
        :return: normalised adjacency with the same shape as the input
        """
        if not adjacency.is_floating_point():
            adjacency = adjacency.to(torch.get_default_dtype())
        degree = adjacency.sum(dim=-1)  # [..., nodes]
        inv_sqrt_degree = degree.pow(-0.5)
        # Zero-degree nodes give 0^(-0.5) = inf; zero them out as before.
        inv_sqrt_degree = inv_sqrt_degree.masked_fill(inv_sqrt_degree == float("inf"), 0.0)
        # diag(d) @ A @ diag(d) == d_i * A_ij * d_j. Scaling rows and columns
        # directly avoids materialising diagonal matrices and stays on the
        # adjacency's own device and dtype.
        return inv_sqrt_degree.unsqueeze(-1) * adjacency * inv_sqrt_degree.unsqueeze(-2)
if __name__ == "__main__":
    # Demo: run a single graph-convolution layer on a toy batch of path graphs.
    batch_size, num_nodes, in_features, out_features = 16, 8, 32, 64
    layer = GraphConvolution(in_features, out_features)
    # Tridiagonal matrix: each node is linked to itself and to its immediate
    # neighbours (self-loops included, as standard_adjacency expects), then
    # broadcast over the batch dimension of size batch_size.
    band = np.eye(num_nodes, k=-1) + np.eye(num_nodes) + np.eye(num_nodes, k=1)
    adjacency = torch.zeros(batch_size, num_nodes, num_nodes) + torch.Tensor(band)
    features = torch.randn(batch_size, num_nodes, in_features)  # one batch of node features
    # Updated node features, with ReLU as the activation function sigma.
    hidden = F.relu(layer(features, adjacency))
    print(hidden.shape)
    print(hidden)
