paper1”,应该是paper2引用了paper1。_pyg安装">
当前位置:   article > 正文

PyG安装及入门(一)

pyg安装

很早以前就想研究一下怎么用PyG,现在终于有一点时间了,开更!   、

一、安装

安装最麻烦的是对齐所有东西的版本,尤其是安装
这里涉及到的有 python-pytorch-cuda-cudnn-PyG相关框架

1.1 创建虚拟环境,安装pytorch

官方安装手册Start Locally | PyTorch

当中有提到python版本3.8-3.11现在基本上支持大多数的pytorch 版本了

Python 3.8-3.11 is generally installed by default on any of our supported Linux distributions, which meets our recommendation.

    所以直接用Python3.9创建虚拟环境
 

  1. # 创建虚拟环境
  2. conda create -n pyg python=3.9
  3. # 进入虚拟环境
  4. conda activate pyg
  5. # 查看所有虚拟环境
  6. conda env list

虚拟环境创建好以后,查看cuda版本,在官网上选择出对应的安装指令

但是官方指令常常会有安装缓慢的时候,具体用pip还是conda可以换着试试,换源也是常用的方法

1.2 安装PyG

官网安装手册Installation — pytorch_geometric documentation
这里我选择的是安装stable版本,最好是把相关的依赖一起安装上

二、导入PyG读取数据集

2.1 Cora数据集介绍

PyG内置了几种常用的数据集,这里主要用到Cora数据集,解决简单的分类问题
Cora数据集是一个机器学习论文数据集,统计了2078篇文章,内含.content、.cites两个文件
.content:通过统计机器学习不同领域中的key words在每篇论文中的出现,给出论文所属的分类,共有7个label
.cites: 记录了论文之间的引用关系,比如"paper1:paper2"代表有向图中的链路"paper2->paper1”,应该是paper2引用了paper1

TODO: 构造一个分类器,利用.cites中的论文引用关系判断论文所属分类,以.content中给出的label为基准,通过可视化、构造混淆矩阵等方式评判分类器的性能
Cora数据集虽然只有一张图,但是充分使用节点之间的连接关系构造节点特征,是图神经网络入门必不可少的数据集之一

2.2 PyG初步读取数据集

PyG的基础教程中给了一段读取Cora数据集的代码,参考Introduction by Example — pytorch_geometric documentation
 

  1. from torch_geometric.datasets import Planetoid
  2. dataset = Planetoid(root='/tmp/Cora', name='Cora')
  3. >>> Cora()
  4. len(dataset)
  5. >>> 1
  6. dataset.num_classes
  7. >>> 7
  8. dataset.num_node_features
  9. >>> 1433
  10. data = dataset[0]
  11. >>> Data(edge_index=[2, 10556], test_mask=[2708],
  12. train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
  13. data.is_undirected()
  14. >>> True
  15. data.train_mask.sum().item()
  16. >>> 140
  17. data.val_mask.sum().item()
  18. >>> 500
  19. data.test_mask.sum().item()
  20. >>> 1000

这里的data为每个节点分配了label,并有额外的node-level属性:
1.train_mask: denotes against which nodes to train (140 nodes)
2.val_mask: denotes which nodes to use for validation, e.g., to perform early stopping (500 nodes)
3.test_mask: denotes against which nodes to test (1000 nodes).

2.3 初步可视化

这里参考Pytorch Geometric 系列教程1:互动可视化Graph数据集 - MyEncyclopedia

将 cora 转换成 networkx 格式,Cora 有 7 种节点类型,将每种节点类型赋予不同颜色,调用 networkx 的 spring_layout 计算每个节点的弹簧布局下的位置
完整代码如下:

  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch_geometric.data import Data
  6. from torch_geometric.nn import GATConv
  7. from torch_geometric.datasets import Planetoid
  8. import torch_geometric.transforms as T
  9. name_data = 'Cora'
  10. dataset = Planetoid(root='./data/', name=name_data)
  11. from torch_geometric.utils import to_networkx
  12. cora = to_networkx(dataset.data)
  13. print(cora.is_directed())
  14. node_classes = dataset.data.y.data.numpy()
  15. print(node_classes)
  16. node_color = ["red","blue","green","yellow","peru","violet","cyan"]
  17. node_label = np.array(list(cora.nodes))
  18. import matplotlib.pyplot as plt
  19. import networkx as nx
  20. pos = nx.layout.spring_layout(cora)
  21. plt.figure(figsize=(16,12))
  22. for i in np.arange(len(np.unique(node_classes))):
  23. node_list = node_label[node_classes == i]
  24. nx.draw_networkx_nodes(cora, pos, nodelist=list(node_list),
  25. node_size=50,
  26. node_color=node_color[i],
  27. alpha=0.8)
  28. nx.draw_networkx_edges(cora, pos,width=1,edge_color="black")
  29. plt.show()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/860440
推荐阅读
相关标签
  

闽ICP备14008679号