当前位置:   article > 正文

PyTorch Geometric基本教程_pytorch geometric networkx

pytorch geometric networkx

PyG官方文档:https://pytorch-geometric.readthedocs.io/


  1. # Install torch geometric
  2. !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
  3. !pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
  4. !pip install -q torch-geometric
  5. import torch
  6. import networkx as nx
  7. import matplotlib.pyplot as plt

1.内置数据集(以KarateClub为例)

  1. from torch_geometric.datasets import KarateClub
  2. dataset = KarateClub()
  3. print(f'Dataset: {dataset}:')
  4. print('======================')
  5. # 图的数量
  6. print(f'Number of graphs: {len(dataset)}')
  7. # 每个节点的特征尺寸
  8. print(f'Number of features: {dataset.num_features}')
  9. # 节点的类别数量
  10. print(f'Number of classes: {dataset.num_classes}')
  1. # 获取具体的图
  2. data = dataset[0]
  3. print(data)
  4. print('==============================================================')
  5. # 获取图的属性
  6. print(f'Number of nodes: {data.num_nodes}')
  7. print(f'Number of edges: {data.num_edges}')
  8. print(f'Average node degree: {(2*data.num_edges) / data.num_nodes:.2f}')
  9. print(f'Number of training nodes: {data.train_mask.sum()}')
  10. print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
  11. print(f'Contains isolated nodes: {data.has_isolated_nodes()}')
  12. print(f'Contains self-loops: {data.has_self_loops()}')
  13. print(f'Is undirected: {data.is_undirected()}')
  1. # 取出的图的数据对象为Data类型,包含以下属性
  2. # 1. edge_index 每条边的两个端点的索引组成的元组
  3. # 2. x 节点特征[节点数量,特征维数]
  4. # 3. y 节点标签(类别),每个节点只分配一个类别
  5. # 4. train_mask
  6. Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])
  7. print(data)
  1. # 获取所有的边
  2. print(data.edge_idx.T)

2.可视化

  1. def visualize(h, color, epoch=None, loss=None, accuracy=None):
  2. plt.figure(figsize=(7,7))
  3. plt.xticks([])
  4. plt.yticks([])
  5. if torch.is_tensor(h):
  6. h = h.detach().cpu().numpy()
  7. plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
  8. if epoch is not None and loss is not None and accuracy['train'] is not None and accuracy['val'] is not None:
  9. plt.xlabel((f'Epoch: {epoch}, Loss: {loss.item():.4f} \n'
  10. f'Training Accuracy: {accuracy["train"]*100:.2f}% \n'
  11. f' Validation Accuracy: {accuracy["val"]*100:.2f}%'),
  12. fontsize=16)
  13. else:
  14. # networkx的draw_networkx
  15. nx.draw_networkx(h, pos=nx.spring_layout(h, seed=42), with_labels=False, node_color=color, cmap="Set2")
  16. plt.show()
'
运行
  1. from torch_geometric.utils import to_networkx
  2. # 将Data类型转换成networkx
  3. G = to_networkx(data, to_undirected=True)
  4. # 将图可视化,节点颜色为节点的类型
  5. visualize(G, color=data.y)

3.搭建GNN(以GCN为例)

  1. import torch
  2. from torch.nn import Linear
  3. from torch_geometric.nn import GCNConv
  4. class GCN(torch.nn.Module):
  5. def __init__(self):
  6. super().__init__()
  7. self.conv1 = GCNConv(dataset.num_features, 4)
  8. self.conv2 = GCNConv(4, 4)
  9. self.conv3 = GCNConv(4, 2)
  10. self.classifier = Linear(2, dataset.num_classes)
  11. def forward(self, x, edge_index):
  12. h = self.conv1(x, edge_index)
  13. h = h.tanh()
  14. h = self.conv2(h, edge_index)
  15. h = h.tanh()
  16. h = self.conv3(h, edge_index)
  17. h = h.tanh()
  18. out = self.classifier(h)
  19. return out, h
  20. model = GCN()
  21. print(model)
  1. # 节点分类
  2. model = GCN()
  3. out, h = model(data.x, data.edge_index)
  4. print(f'Embedding shape: {list(h.shape)}')
  5. visualize(h, color=data.y)

4.在KarateClub数据集上训练

  1. import time
  2. model = GCN()
  3. # 交叉熵损失,Adam优化器
  4. criterion = torch.nn.CrossEntropyLoss()
  5. optimizer = torch.optim.Adam(model.parameters())
  6. def train(data):
  7. optimizer.zero_grad()
  8. out, h = model(data.x, data.edge_index)
  9. # 只对train_mask的节点计算loss
  10. loss = criterion(out[data.train_mask], data.y[data.train_mask])
  11. loss.backward()
  12. optimizer.step()
  13. accuracy = {}
  14. # torch.argmax 取置信度最大的一类
  15. predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
  16. target_classes = data.y[data.train_mask]
  17. accuracy['train'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())
  18. predicted_classes = torch.argmax(out, axis=1)
  19. target_classes = data.y
  20. accuracy['val'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())
  21. return loss, h, accuracy
  1. for epoch in range(500):
  2. loss, h, accuracy = train(data)
  3. if epoch % 10 == 0:
  4. visualize(h, color=data.y, epoch=epoch, loss=loss, accuracy=accuracy)
  5. time.sleep(0.3)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/860472
推荐阅读
相关标签
  

闽ICP备14008679号