当前位置:   article > 正文

【DGL系列】remove_nodes从graph中删除节点

【DGL系列】remove_nodes从graph中删除节点

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~


背景说明

从graph中删除节点在dgl中提供了两种形式:

  • dgl.remove_nodes:返回新的graph
  • dgl.DGLGraph.remove_nodes:直接在原来的graph上操作

他们的用法类似,入口参数也相似。

dgl.remove_nodes

dgl.remove_nodes — DGL 2.3 documentation

函数说明:

        删除指定的节点并返回一个新graph。同时删除相应的特征,从节点相连的边也将被移除。删除后,DGL 会使用 ID 从 0 开始的剩余节点和边重新标记。

函数定义:

        dgl.remove_nodes(gnidsntype=Nonestore_ids=False)

参数说明:

  • g (DGLGraph) – 要删除的graph。
  • nids (int, Tensor, iterable[int]) - 要删除的节点。
  • ntype (str, 可选) - 要删除的节点的类型。如果图中只有一种节点类型,则可以省略。
  • store_ids (bool, 可选) – 如果为 True,它将在结果图的 ndata 和 edata 中存储提取的节点和边的原始 ID,分别名为 dgl.NID 和 dgl.EID

注意事项:

        此函数将丢弃批处理信息。请使用 dgl.DGLGraph.set_batch_num_nodes 和 dgl.DGLGraph.set_batch_num_edges 来维护信息。

        当设置 store_ids=True 时,DGL 会在图的内部存储被删除的节点 ID。这在需要后续访问这些节点 ID 时特别有用。

代码示例:

  1. import dgl
  2. import torch
  3. # 定义图的边
  4. src_nodes = torch.tensor([0, 1, 2, 3, 4]) # 起始节点
  5. dst_nodes = torch.tensor([1, 2, 3, 4, 5]) # 结束节点
  6. # 创建图对象
  7. g = dgl.graph((src_nodes, dst_nodes))
  8. # 图是无向的,所以添加反向边
  9. g = dgl.to_bidirected(g)
  10. print("删除前的图:", g)
  11. print("删除前的节点IDs:", g.nodes())
  12. # 删除节点 2
  13. new_g = dgl.remove_nodes(g, torch.tensor([2]))
  14. print("删除后的图:", new_g)
  15. print("删除后,新的节点IDs:", new_g.nodes())
  16. print("#" * 50 )
  17. # 删除节点 2,存储删除的节点 ID
  18. new_g = dgl.remove_nodes(g, torch.tensor([2]), store_ids=True)
  19. print("删除后的图:", new_g)
  20. print("删除后,新的节点IDs:", new_g.nodes())
  21. print("删除后,保留的原始节点IDs:", new_g.ndata[dgl.NID])

dgl.DGLGraph.remove_nodes

dgl.DGLGraph.remove_nodes — DGL 2.3 documentation

函数说明:

        删除具有指定节点类型的多个节点,连接到节点的边也将被移除。删除节点和边后,将使用从 0 开始的连续整数重新索引其余节点和边,并保留它们的相对顺序。已删除节点/边缘的特征将相应地移除。

函数定义:

        DGLGraph.remove_nodes(nidsntype=Nonestore_ids=False)

参数说明:

  • nids (int, tensor, numpy.ndarray, list) - 要删除的节点。
  • ntype (str, 可选) - 要删除的节点的类型。如果图中只有一种节点类型,则可以省略。
  • store_ids (bool, 可选) – 如果为 True,它将在结果图的 ndata 和 edata 中存储提取的节点和边的原始 ID,分别名为 dgl.NID 和 dgl.EID

注意事项:

        此函数保留批处理信息。

        当设置 store_ids=True 时,DGL 会在图的内部存储被删除的节点 ID。这在需要后续访问这些节点 ID 时特别有用。

代码示例:

  1. import dgl
  2. import torch
  3. # 定义图的边
  4. src_nodes = torch.tensor([0, 1, 2, 3, 4]) # 起始节点
  5. dst_nodes = torch.tensor([1, 2, 3, 4, 5]) # 结束节点
  6. # 创建图对象
  7. g = dgl.graph((src_nodes, dst_nodes))
  8. # 图是无向的,所以添加反向边
  9. g = dgl.to_bidirected(g)
  10. print("删除前的图:", g)
  11. print("删除前的节点IDs:", g.nodes())
  12. # 删除节点 2
  13. g.remove_nodes(torch.tensor([2]))
  14. print("删除后的图:", g)
  15. print("删除后,新的节点IDs:", g.nodes())
  16. print("#" * 50 )
  17. # 删除节点 2,存储删除的节点 ID
  18. g = dgl.graph((src_nodes, dst_nodes))
  19. g = dgl.to_bidirected(g)
  20. g.remove_nodes(torch.tensor([2]), store_ids=True)
  21. print("删除后的图:", g)
  22. print("删除后,新的节点IDs:", g.nodes())
  23. print("删除后,保留的原始节点IDs:", g.ndata[dgl.NID])

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号