当前位置:   article > 正文

图神经网络改进-手把手教你改代码-第4期:GCN、GAT手动实现_实现gat的改进

实现gat的改进

  本系列项目主攻:代码分享与讲解创新思路解析前沿模块缝合二次创新实现方法
项目主要提供关于:

  • 图神经网络
  • 图对比学习
  • 图结构学习
  • 超图神经网络
  • 超图对比学习
  • 超图结构学习

  这六种方向的通用模型、原创代码以及改进思路,供大家参考学习,后续还会持续更新针对链路预测、节点分类等下游任务上的代码以及改进思路,帮助大家提升代码水平,多发论文。


 希望可以帮助大家快速上手实践图神经网络实践是最好的入门方式!

  祝大家论文顺利,accept冲冲冲!

第4期:手动实现GCN、GAT模块


本期为大家带来了图神经网络基础模块:GCN、GAT的复现代码,并附有逐行注释和视频讲解


Q:为什么要自己实现这些图神经网络基础模块?
A:PYG虽然提供了许多基础模块供直接调用,但是不利于基础创新,或不方便获取这些模块的中间输出等。


详细讲解视频:【图神经网络改进-手把手教你改代码-第4期】

项目Github:图小狮


以下为具体的代码部分:

GCN类


1.类继承自nn.Module

class GCNConv(nn.Module):
  • 1

2.__init__方法

def __init__(self, in_features, out_features):
        super(GCNConv, self).__init__()
        # 输入特征的维度
        self.in_features = in_features
        # 输出特征的维度
        self.out_features = out_features
        # 定义可学习的权重矩阵
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        # 定义可学习的偏置向量
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        # 初始化参数
        self.reset_parameters()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

3.参数初始化

def reset_parameters(self):
        # 使用Xavier初始化权重矩阵
        nn.init.xavier_uniform_(self.weight)
        # 将偏置向量初始化为0
        nn.init.zeros_(self.bias)
  • 1
  • 2
  • 3
  • 4
  • 5

4.forward方法

def forward(self, x, edge_index):
        # x: 节点特征矩阵,大小为[num_nodes, in_features]
        # edge_index: 边索引,大小为[2, num_edges]

        # 计算规范化的邻接矩阵
        num_nodes = x.size(0)
        # 创建一个全零矩阵作为邻接矩阵的初始状态
        adj = torch.zeros(num_nodes, num_nodes).cuda()
        # 根据edge_index填充邻接矩阵,无向图因此两个方向都要填充
        adj[edge_index[0], edge_index[1]] = 1
        adj[edge_index[1], edge_index[0]] = 1
        # 计算每个节点的度
        deg = torch.sum(adj, dim=1)
        # 计算度矩阵的逆平方根,用于后续的归一化
        deg_inv_sqrt = 1.0 / torch.sqrt(deg)  # 加上一个小的常数避免除零错误
        # 计算对称归一化的邻接矩阵
        norm_adj = adj * deg_inv_sqrt.unsqueeze(1)
        norm_adj = norm_adj * deg_inv_sqrt.unsqueeze(0)

        # 支撑传播:线性变换节点特征
        support = torch.matmul(x, self.weight)  # X * W
        # 消息传递:通过归一化的邻接矩阵传播特征
        output = torch.matmul(norm_adj, support) + self.bias  # D^{-1/2} * A * D^{-1/2} * (X * W) + b
        return output  # 返回输出,可以选择使用ReLU等激活函数进行非线性变换
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

GAT类


1.类继承自nn.Module

class GCNConv(nn.Module):
  • 1

2.__init__方法

def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2, concat=True):
        super(GATConv, self).__init__()
        # 定义dropout率,用于在注意力系数上进行dropout操作以防止过拟合
        self.dropout = dropout
        # 输入特征的维度
        self.in_features = in_features
        # 输出特征的维度
        self.out_features = out_features
        # LeakyReLU非线性激活函数中的负斜率alpha
        self.alpha = alpha
        # 是否在多头注意力中进行拼接,对于最后一层通常设为False,使用平均
        self.concat = concat
        # 定义可学习的权重矩阵W,用于线性变换输入特征
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # 定义注意力机制中可学习的参数a
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        # 定义LeakyReLU激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

3.注意力系数矩阵生成

def _prepare_attentional_mechanism_input(self, Wh):
        # 这个函数负责计算注意力系数
        # 首先通过与a的前半部分做矩阵乘法计算得到每个节点的影响力分数Wh1
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
        # 通过与a的后半部分做矩阵乘法计算得到每个节点被影响的分数Wh2
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
        # 将Wh1加上Wh2的转置,得到每一对节点的非归一化注意力分数e
        e = Wh1 + Wh2.T
        # 使用LeakyReLU激活函数处理e,增加非线性
        return self.leakyrelu(e)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

4.forward方法

def forward(self, h, edge_index):
        # 将边索引转换为稠密邻接矩阵,并去除多余的维度
        adj = to_dense_adj(edge_index).squeeze(0)
        # 应用线性变换
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        # 准备注意力机制的输入
        e = self._prepare_attentional_mechanism_input(Wh)
        # 创建一个足够小的向量用于掩盖不存在的边
        zero_vec = -9e15*torch.ones_like(e)
        # 只有当adj中存在边时,才保留e中的值,否则用zero_vec中的极小值代替
        attention = torch.where(adj > 0, e, zero_vec)
        # 对注意力系数进行softmax操作,使得每个节点的注意力系数和为1
        attention = F.softmax(attention, dim=1)
        # 对注意力系数进行dropout
        attention = F.dropout(attention, self.dropout, training=self.training)
        # 应用注意力机制更新节点特征
        h_prime = torch.matmul(attention, Wh)

        # 如果concat为真,则对输出使用ELU激活函数;否则直接返回结果
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

Powered By 图小狮

希望能够得到大家的喜欢,您的点赞收藏即是对我们最大的支持!

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号