当前位置:   article > 正文

如何使用pytorch写一个图卷积神经网络(GCN)_pytorch gcn

pytorch gcn

1. 简介

(前段时间一直在用图卷积网络做信号处理与识别任务,所以抽了点时间给大家分享一下图卷积的基本计算和实现代码,为更好地方便大家理解和学习,代码部分已经设置了案例并且可以直接运行,喜欢的可以收藏,但请注意本文是禁止转载的哈)

       图卷积网络(Graph Convolutional Network,GCN)是近年来逐渐流行的一种神经网络结构,它能够处理具有广义拓扑图结构的数据,并深入发掘其特征和规律。图卷积网络的关键是学习一个函数,该函数通过聚合节点自身的特征及其邻居节点的特征来生成节点的表示。图卷积网络的应用非常广泛,包括但不限于以下几个方面:

  1. 社交网络分析:图卷积网络可以应用于社交网络分析,用于识别社区结构、推荐好友、分析信息传播等。例如,通过构建用户交互图,图卷积网络可以学习用户的兴趣和偏好,从而为用户推荐相似兴趣的好友或内容。
  2. 生物信息学:在生物信息学领域,图卷积网络可以用于蛋白质结构分析、药物发现等。通过分析分子结构的图表示,图卷积网络可以预测蛋白质的功能和药物的作用机制。
  3. 化学和材料科学:图卷积网络可以用于分子性质预测、新材料设计等。通过学习分子的图结构,图卷积网络可以预测其化学性质,为新材料的设计和合成提供有力支持。
  4. 交通预测:在智能交通系统中,图卷积网络可以用于预测交通流量、交通拥堵情况等。通过构建交通网络图,图卷积网络可以学习交通流量的时空变化规律,为交通管理和规划提供科学依据。
  5. 能源系统:图卷积网络可以用于电力系统的优化和管理,预测能源消耗,提高能源效率。通过构建电力网络图,图卷积网络可以学习电力负荷的时空分布特征,为电力调度和节能管理提供决策支持。
  6. 自然语言处理:图卷积网络能够对句子进行结构化建模,从而提升一些NLP任务的性能,如文本分类、情感分析等。通过将文本表示为图结构,图卷积网络可以学习文本之间的依赖关系和语义信息,提高文本处理的准确性和效率。

此外,图卷积网络还在图像分类、图像检索、推荐系统等领域有着广泛的应用。

2. 基本理论和计算

         图卷积神经网络可以被划分为谱域图卷积和空域图卷积。在本文中,我们主要介绍空域图卷积。空域图卷积将卷积操作直接定义在每个节点的连接关系上,通过节点的连接关系来实现节点间信息的传播与更新。通常,一个图可以被表征为$G(X,A,E)$, 其中X表示节点特征集,$A \in {R^{n \times n}}$为邻接矩阵,E为节点边连接属性,n为节点数。通过引入自环,新的邻接矩阵可以被表征为:

$\tilde A = A + I$

其中,$I \in {R^{n \times n}}$为单位矩阵。令对角矩阵D表示度矩阵则其满足:

${d_{ii}} = \sum\nolimits_{k = 1}^n {​{​{\tilde A}_{ik}}} $

其中,${d_{ii}}$为度矩阵第i行i列元素, ${\tilde A}_{ik}$为邻接矩阵第i行k列元素。为防止多边节点特征值过大,对邻接矩阵进行对称归一化,既${D^{ - \frac{1}{2}}}\tilde A{D^{ - \frac{1}{2}}}$。因此,空域图卷积运算可以被计算为:

${H^1} = \sigma ({D^{ - \frac{1}{2}}}\tilde A{D^{ - \frac{1}{2}}}X{W^1})$

其中,X为经过一次图卷积运算后的节点特征集,$\sigma$为激活函数, W为可学习参数。值得注意的是,每次图卷积运算均会聚合节点1-hop邻居的特征,通过叠加n次图卷积运算可以获取n-hop邻居节点的信息。

 

3. 代码实现

  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.parameter import Parameter
  4. import torch.nn.init as init
  5. import math
  6. import numpy as np
  7. import torch.nn.functional as F
  8. class GraphConvolution(nn.Module):
  9. def __init__(self, input_dim: int, output_dim: int, bias: bool = True, device=None, dtype=None) -> None:
  10. """
  11. 空域图卷积:
  12. :param input_dim 输入单节点特征数
  13. :param output_dim 输出单节点特征数
  14. :param use_bias 是否使用偏置
  15. """
  16. super(GraphConvolution, self).__init__()
  17. factory_kwargs = {'device': device, 'dtype': dtype}
  18. self.input_dim = input_dim
  19. self.output_dim = output_dim
  20. self.use_bias = bias
  21. self.weight = Parameter(torch.empty((input_dim, output_dim), **factory_kwargs))
  22. if bias:
  23. self.bias = Parameter(torch.empty(output_dim, **factory_kwargs))
  24. else:
  25. self.register_parameter('bias', None)
  26. self.reset_parameters()
  27. def reset_parameters(self) -> None:
  28. # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
  29. # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
  30. # https://github.com/pytorch/pytorch/issues/57109
  31. init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  32. if self.bias is not None:
  33. fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
  34. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  35. init.uniform_(self.bias, -bound, bound)
  36. def forward(self, input_feature, adjacency):
  37. """
  38. adjacency: torch.FloatTensor 邻接矩阵 [nodes, nodes]
  39. input_feature: torch.Tensor 输入节点集合: [batch, nodes, features]
  40. process:
  41. H(l+1) = D^(-1/2)*A*D^(-1/2)*H(l)*W
  42. """
  43. adjacency = self.standard_adjacency(adjacency) # (8, 18, 8, 8)
  44. support = torch.einsum("bij,bjk->bik", [adjacency, input_feature])
  45. output = torch.matmul(support, self.weight)
  46. if self.use_bias:
  47. output += self.bias
  48. return output
  49. def standard_adjacency(self, adjacency): # 注意这里的邻接矩阵应带有自环
  50. """
  51. :param adjacency: 邻接矩阵 [batch, nodes, nodes]
  52. :return: 标准化邻接矩阵: [batch, nodes, nodes]
  53. """
  54. degree_matrix = torch.sum(adjacency, dim=-1, keepdim=False) # [8, 18, 8]
  55. degree_matrix = degree_matrix.pow(-0.5)
  56. degree_matrix[degree_matrix == float("inf")] = 0. # [64, 18]
  57. degree_matrix = degree_matrix.reshape(-1, len(degree_matrix[0]), 1) * torch.eye(
  58. len(degree_matrix[0]), len(degree_matrix[0])) # (16, 8, 8)
  59. return torch.matmul(torch.matmul(degree_matrix, 1.0 * adjacency), degree_matrix) # [64, 18, 18]
  60. if __name__ == "__main__":
  61. graph_conv = GraphConvolution(32, 64)
  62. # 这是我们随意设置的一个邻接矩阵,16 表示batch
  63. adjacency = torch.zeros(16, 8, 8) + torch.Tensor(np.eye(8, k=-1) + np.eye(8) + np.eye(8, k=1))
  64. x = torch.randn(16, 8, 32) # 输入一个batch 的节点特征
  65. h = F.relu(graph_conv(x, adjacency)) # 输出更新后的节点特征, 并使用relu作为激活函数
  66. print(h.shape)
  67. print(h)

 运行:

要注意,所需要的包一定要导全。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/天景科技苑/article/detail/767226
推荐阅读
相关标签
  

闽ICP备14008679号