当前位置:   article > 正文

【Graph Net】【专题系列】四、GAT代码实战_gat训练代码

gat训练代码

【Graph Net】【专题系列】四、GAT代码实战

目录

一、简介

二、代码

三、模型结构、结果分析

四、展望


一、简介

        GAT(Graph Attention Network)在图结构的基础上,加上了注意力这个东东,这期我不装了,图NNer不搞PyG,感觉就是对不起巨人们的肩膀!此处吹爆PyG!

        本文目标是基于PyG写一个GAT的Demo,在Cora数据集上实现节点分类。这里Cora的读取是我自己用pandas、sklearn直接解析原始文本文件实现的,也算是我自己的一点点贡献,方便后来者用自己的数据集做二次开发。话不多说,直接上代码。图的相关代码可见仓库:GitHub - mapstory6788/Graph-Networks

二、代码

  1. import os
  2. import time
  3. import random
  4. import torch
  5. import torch.nn.functional as F
  6. from torch_geometric.nn import GATConv
  7. from torch_geometric.data import Data
  8. import numpy as np
  9. import pandas as pd
  10. import scipy.sparse as sp
  11. from sklearn.preprocessing import LabelEncoder
  12. #配置项
  13. class configs():
  14. def __init__(self):
  15. # Data
  16. self.data_path = r'./data/cora'
  17. self.save_model_dir = './'
  18. self.model_name = r'GAT'
  19. self.seed = 2023
  20. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  21. self.epoch = 500
  22. self.in_features = 1433 #core ~ feature:1433
  23. self.hidden_features = 16 # 隐层数量
  24. self.output_features = 8 # core~paper-point~ 8类
  25. self.learning_rate = 0.01
  26. self.dropout = 0.5
  27. self.istrain = True
  28. self.istest = True
  29. cfg = configs()
  30. def seed_everything(seed=2023):
  31. random.seed(seed)
  32. os.environ['PYTHONHASHSEED']=str(seed)
  33. np.random.seed(seed)
  34. torch.manual_seed(seed)
  35. seed_everything(seed = cfg.seed)
  36. # 读取Cora数据集 return geometric Data格式
  37. def index_to_mask(index, size):
  38. mask = np.zeros(size, dtype=bool)
  39. mask[index] = True
  40. return mask
  41. def load_cora_data(data_path = cfg.data_path):
  42. content_df = pd.read_csv(os.path.join(data_path,"cora.content"), delimiter="\t", header=None)
  43. content_df.set_index(0, inplace=True)
  44. index = content_df.index.tolist()
  45. features = sp.csr_matrix(content_df.values[:,:-1], dtype=np.float32)
  46. # 处理标签
  47. labels = content_df.values[:,-1]
  48. class_encoder = LabelEncoder()
  49. labels = class_encoder.fit_transform(labels)
  50. # 读取引用关系
  51. cites_df = pd.read_csv(os.path.join(data_path,"cora.cites"), delimiter="\t", header=None)
  52. cites_df[0] = cites_df[0].astype(str)
  53. cites_df[1] = cites_df[1].astype(str)
  54. cites = [tuple(x) for x in cites_df.values]
  55. edges = [(index.index(int(cite[0])), index.index(int(cite[1]))) for cite in cites]
  56. edges = np.array(edges).T
  57. # 构造Data对象
  58. data = Data(x=torch.from_numpy(np.array(features.todense())),
  59. edge_index=torch.LongTensor(edges),
  60. y=torch.from_numpy(labels))
  61. idx_train = range(140)
  62. idx_val = range(200, 500)
  63. idx_test = range(500, 1500)
  64. data.train_mask = index_to_mask(idx_train, size=labels.shape[0])
  65. data.val_mask = index_to_mask(idx_val, size=labels.shape[0])
  66. data.test_mask = index_to_mask(idx_test, size=labels.shape[0])
  67. return data
  68. class GAT(torch.nn.Module):
  69. def __init__(self, in_channels, out_channels, heads=8, dropout=cfg.dropout, bias=True):
  70. super(GAT, self).__init__()
  71. self.conv1 = GATConv(in_channels, out_channels, heads=heads, concat=True, dropout=dropout, bias=bias)
  72. self.conv2 = GATConv(heads * out_channels, out_channels, heads=heads, concat=False, dropout=dropout, bias=bias)
  73. def forward(self, data):
  74. x, edge_index = data.x, data.edge_index
  75. x = F.dropout(x, p=0.6, training=self.training)
  76. x = F.elu(self.conv1(x, edge_index))
  77. x = F.dropout(x, p=0.6, training=self.training)
  78. x = self.conv2(x, edge_index)
  79. return F.log_softmax(x, dim=1)
  80. class myGAT_run():
  81. def train(self):
  82. t = time.time()
  83. dataset = load_cora_data()
  84. model = GAT(dataset.num_features, cfg.output_features).to(cfg.device)
  85. data = dataset
  86. optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=5e-4)
  87. model.train()
  88. for epoch in range(cfg.epoch):
  89. optimizer.zero_grad()
  90. output = model(data)
  91. preds = output.max(dim=1)[1]
  92. loss_train = F.nll_loss(output[data.train_mask], data.y[data.train_mask].long())
  93. correct = preds[data.train_mask].eq(data.y[data.train_mask]).sum().item()
  94. acc_train = correct / int(data.train_mask.sum())
  95. loss_train.backward()
  96. optimizer.step()
  97. loss_val = F.nll_loss(output[data.val_mask], data.y[data.val_mask].long())
  98. correct = preds[data.val_mask].eq(data.y[data.val_mask]).sum().item()
  99. acc_val = correct / int(data.val_mask.sum())
  100. print('Epoch: {:04d}'.format(epoch + 1),
  101. 'loss_train: {:.4f}'.format(loss_train.item()),
  102. 'acc_train: {:.4f}'.format(acc_train),
  103. 'loss_val: {:.4f}'.format(loss_val.item()),
  104. 'acc_val: {:.4f}'.format(acc_val),
  105. 'time: {:.4f}s'.format(time.time() - t))
  106. torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth')) # 模型保存
  107. def infer(self):
  108. #Create Test Processing
  109. dataset = load_cora_data()
  110. data = dataset
  111. model_path = os.path.join(cfg.save_model_dir, 'latest.pth')
  112. model = torch.load(model_path, map_location=torch.device(cfg.device))
  113. model.eval()
  114. output = model(data)
  115. params = sum(p.numel() for p in model.parameters())
  116. preds = output.max(dim=1)[1]
  117. loss_test = F.nll_loss(output[data.test_mask], data.y[data.test_mask].long())
  118. correct = preds[data.test_mask].eq(data.y[data.test_mask]).sum().item()
  119. acc_test = correct / int(data.test_mask.sum())
  120. print("Test set results:",
  121. "loss= {:.4f}".format(loss_test.item()),
  122. "accuracy= {:.4f}".format(acc_test),
  123. 'params={:.4f}k'.format(params/1024))
  124. if __name__ == '__main__':
  125. mygraph = myGAT_run()
  126. if cfg.istrain == True:
  127. mygraph.train()
  128. if cfg.istest == True:
  129. mygraph.infer()

三、模型结构、结果分析

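| Layer (type)     | Output Shape | Param # |
|------------------|--------------|---------|
| Linear-1         | [-1, 64]     | 91712   |
| SumAggregation-2 | [-1, 8, 8]   | 0       |
| GATConv-3        | [-1, 64]     | 64      |
| Linear-4         | [-1, 64]     | 4096    |
| SumAggregation-5 | [-1, 8, 8]   | 0       |
| GATConv-6        | [-1, 8]      | 8       |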

Total params(上表合计): 95,880;代码中按 params/1024 统计的结果约为 93.88K(该统计还包含上表未单独列出的注意力向量参数,共 96,136 个),模型结构还是比较精简的。
在Cora数据集上,训练epoch=500,accuracy= 0.7590

四、展望

后面的文章考虑从代码的角度来研究Graph Embedding,以及它与GNNs的联系。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/368851
推荐阅读
相关标签
  

闽ICP备14008679号