赞
踩
转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]
如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~
从graph中删除节点在dgl中提供了两种形式:
他们的用法类似,入口参数也相似。
函数说明:
删除指定的节点并返回一个新graph。同时删除相应的特征,从节点相连的边也将被移除。删除后,DGL 会使用 ID 从 0 开始的剩余节点和边重新标记。
函数定义:
dgl.remove_nodes(g, nids, ntype=None, store_ids=False)
参数说明:
ndata
和 edata
中存储提取的节点和边的原始 ID,分别名为 dgl.NID
和 dgl.EID
。注意事项:
此函数将丢弃批处理信息。请使用 dgl.DGLGraph.set_batch_num_nodes 和 dgl.DGLGraph.set_batch_num_edges 来维护信息。
当设置 store_ids=True
时,DGL 会在图的内部存储被删除的节点 ID。这在需要后续访问这些节点 ID 时特别有用。
代码示例:
- import dgl
- import torch
-
- # 定义图的边
- src_nodes = torch.tensor([0, 1, 2, 3, 4]) # 起始节点
- dst_nodes = torch.tensor([1, 2, 3, 4, 5]) # 结束节点
- # 创建图对象
- g = dgl.graph((src_nodes, dst_nodes))
- # 图是无向的,所以添加反向边
- g = dgl.to_bidirected(g)
-
- print("删除前的图:", g)
- print("删除前的节点IDs:", g.nodes())
-
- # 删除节点 2
- new_g = dgl.remove_nodes(g, torch.tensor([2]))
- print("删除后的图:", new_g)
- print("删除后,新的节点IDs:", new_g.nodes())
-
- print("#" * 50 )
-
- # 删除节点 2,存储删除的节点 ID
- new_g = dgl.remove_nodes(g, torch.tensor([2]), store_ids=True)
- print("删除后的图:", new_g)
- print("删除后,新的节点IDs:", new_g.nodes())
- print("删除后,保留的原始节点IDs:", new_g.ndata[dgl.NID])
函数说明:
删除具有指定节点类型的多个节点,连接到节点的边也将被移除。删除节点和边后,将使用从 0 开始的连续整数重新索引其余节点和边,并保留它们的相对顺序。已删除节点/边缘的特征将相应地移除。
函数定义:
DGLGraph.remove_nodes(nids, ntype=None, store_ids=False)
参数说明:
ndata
和 edata
中存储提取的节点和边的原始 ID,分别名为 dgl.NID
和 dgl.EID
。注意事项:
此函数保留批处理信息。
当设置 store_ids=True
时,DGL 会在图的内部存储被删除的节点 ID。这在需要后续访问这些节点 ID 时特别有用。
代码示例:
- import dgl
- import torch
-
- # 定义图的边
- src_nodes = torch.tensor([0, 1, 2, 3, 4]) # 起始节点
- dst_nodes = torch.tensor([1, 2, 3, 4, 5]) # 结束节点
- # 创建图对象
- g = dgl.graph((src_nodes, dst_nodes))
- # 图是无向的,所以添加反向边
- g = dgl.to_bidirected(g)
-
- print("删除前的图:", g)
- print("删除前的节点IDs:", g.nodes())
-
- # 删除节点 2
- g.remove_nodes(torch.tensor([2]))
- print("删除后的图:", g)
- print("删除后,新的节点IDs:", g.nodes())
-
- print("#" * 50 )
-
- # 删除节点 2,存储删除的节点 ID
- g = dgl.graph((src_nodes, dst_nodes))
- g = dgl.to_bidirected(g)
- g.remove_nodes(torch.tensor([2]), store_ids=True)
- print("删除后的图:", g)
- print("删除后,新的节点IDs:", g.nodes())
- print("删除后,保留的原始节点IDs:", g.ndata[dgl.NID])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。