当前位置:   article > 正文

深入理解图注意力网络代码(GAT)_gat代码详解

gat代码详解

目录

一、整体代码

二、解释 拼接操作

一)创建所有可能的配对

二)拼接以形成配对

三)示例

①假设 out_features 为 2,我们的序列 h 为

②h.repeat(1, N) 生成的矩阵将是

③这里插播一个pytorch中view的使用方法

④因此h.repeat(1, N).view(N*N, -1) 的输出结果是

⑤h.repeat(N, 1) 生成的矩阵将是

⑥torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).拼接后的矩阵是


一、整体代码

  1. class GraphAttentionLayer(nn.Module):
  2. """
  3. Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
  4. 图注意力层
  5. """
  6. def __init__(self, in_features, out_features, dropout, alpha, concat=True):
  7. super(GraphAttentionLayer, self).__init__()
  8. self.in_features = in_features # 节点表示向量的输入特征维度
  9. self.out_features = out_features # 节点表示向量的输出特征维度
  10. self.dropout = dropout # dropout参数
  11. self.alpha = alpha # leakyrelu激活的参数
  12. self.concat = concat #
  13. # 定义可训练参数,即论文中的W和a
  14. self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
  15. nn.init.xavier_uniform_(self.W.data, gain=1.414) # xavier初始化
  16. self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
  17. nn.init.xavier_uniform_(self.a.data, gain=1.414) # xavier初始化
  18. # 定义leakyrelu激活函数
  19. self.leakyrelu = nn.LeakyReLU(self.alpha)
  20. def forward(self, inp, adj):
  21. """
  22. inp: input_fea [N, in_features] in_features表示节点的输入特征向量元素个数
  23. adj: 图的邻接矩阵 维度[N, N] 非零即一,数据结构基本知识
  24. """
  25. h = torch.mm(inp, self.W) # [N, out_features]
  26. N = h.size()[0] # N 图的节点数
  27. a_input = torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)
  28. # [N, N, 2*out_features]
  29. e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
  30. # [N, N, 1] => [N, N] 图注意力的相关系数(未归一化)
  31. zero_vec = -1e12 * torch.ones_like(e) # 将没有连接的边置为负无穷
  32. attention = torch.where(adj>0, e, zero_vec) # [N, N]
  33. # 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留,
  34. # 否则需要mask并置为非常小的值,原因是softmax的时候这个最小值会不考虑。
  35. attention = F.softmax(attention, dim=1) # softmax形状保持不变 [N, N],得到归一化的注意力权重!
  36. attention = F.dropout(attention, self.dropout, training=self.training) # dropout,防止过拟合
  37. h_prime = torch.matmul(attention, h) # [N, N].[N, out_features] => [N, out_features]
  38. # 得到由周围节点通过注意力权重进行更新的表示
  39. if self.concat:
  40. return F.elu(h_prime)
  41. else:
  42. return h_prime
  43. def __repr__(self):
  44. return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

二、解释 拼接操作

也就是下面这个代码:

a_input = torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)

它构建了一个特定的数据结构,这个结构允许每个元素都能与序列中的每个其他元素配对。

一)创建所有可能的配对

h.repeat(1, N).view(N*N, -1)

这个操作重复了 h 中的每个元素 N 次,创建了一个包含所有可能的“行”配对的张量。例如,如果 h 是一个序列中的元素,这个操作就创建了一个包含这个元素与序列中每个元素(包括它自己)的配对的张量。

h.repeat(N, 1)

这个操作重复整个 h N 次,创建了一个包含所有可能的“列”配对的张量。

二)拼接以形成配对

通过 torch.cat,这两个重复的张量被沿着特定的维度拼接起来。这意味着对于 h 中的每个元素,你现在有了一个包含了它与序列中每个其他元素的配对的完整集合。

三)示例

①假设 out_features 为 2,我们的序列 h
  1. h1=[1,2]
  2. h2=[3,4]
  3. h3=[5,6]

②h.repeat(1, N) 生成的矩阵将是
  1. [1,2],[1,2],[1,2]
  2. [3,4],[3,4],[3,4]
  3. [5,6],[5,6],[5,6]
③这里插播一个pytorch中view的使用方法
  1. #初始化一个tensor
  2. import torch
  3. a1 = torch.arange(0,16)
  4. print(a1)
  5. #输出为:tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
  6. a2 = a1.view(8, 2)
  7. a3 = a1.view(2, 8)
  8. a4 = a1.view(4, 4)
  9. print(a2)
  10. print(a3)
  11. print(a4)

输出为:

  1. tensor([[ 0, 1],
  2. [ 2, 3],
  3. [ 4, 5],
  4. [ 6, 7],
  5. [ 8, 9],
  6. [10, 11],
  7. [12, 13],
  8. [14, 15]])
  9. tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
  10. [ 8, 9, 10, 11, 12, 13, 14, 15]])
  11. tensor([[ 0, 1, 2, 3],
  12. [ 4, 5, 6, 7],
  13. [ 8, 9, 10, 11],
  14. [12, 13, 14, 15]])
④因此h.repeat(1, N).view(N*N, -1) 的输出结果是
 
  1. [1,2]
  2. [1,2]
  3. [1,2]
  4. [3,4]
  5. [3,4]
  6. [3,4]
  7. [5,6]
  8. [5,6]
  9. [5,6]
h.repeat(N, 1) 生成的矩阵将是
  1. [1,2]
  2. [3,4]
  3. [5,6]
  4. [1,2]
  5. [3,4]
  6. [5,6]
  7. [1,2]
  8. [3,4]
  9. [5,6]
⑥torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).拼接后的矩阵是
  1. [1,2], [1,2]
  2. [1,2], [3,4]
  3. [1,2], [5,6]
  4. [3,4], [1,2]
  5. [3,4], [3,4]
  6. [3,4], [5,6]
  7. [5,6], [1,2]
  8. [5,6], [3,4]
  9. [5,6], [5,6]

这个矩阵包含了序列中每个元素对(例如 [h_1, h_1], [h_1, h_2], [h_1, h_3] 等)的组合。这为计算自注意力机制中的每个元素与序列中其他所有元素之间的关系提供了基础。每一行代表一个唯一的元素配对,这使得模型能够针对每个配对计算注意力分数。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/571528
推荐阅读
相关标签
  

闽ICP备14008679号