赞
踩
创建一个图,信息如下:
定义数据:x是每个点的输入特征,y是每个点的标签。x的维度为[M,F],M表示结点数,F表示特征个数
x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
定义邻接矩阵:顺序是无所谓的,上下两种是一样的
edge_index = torch.tensor([[0, 1, 2, 0, 3],#起始点
[1, 0, 1, 3, 2]], dtype=torch.long)#终止点
edge_index = torch.tensor([[0, 2, 1, 0, 3],
[3, 1, 0, 1, 2]], dtype=torch.long)
创建torch_geometric中的图,通过torch_geometric.data
- from torch_geometric.data import Data
-
- x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)
- y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
-
- edge_index =
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。