当前位置:   article > 正文

LAD-GNN标签注意蒸馏

LAD-GNN标签注意蒸馏

本文所涉及所有资源均在传知代码平台可获取。

介绍论文:Label Attentive Distillation for GNN-Based Graph Classification

在当今的数据科学领域,Graph Neural Networks (GNNs) 已成为处理图结构数据的强大工具。然而,传统的GNN在图分类任务中面临一个重要挑战——嵌入不对齐问题。本文将介绍一篇名为“Label Attentive Distillation for GNN-Based Graph Classification”的论文,该论文提出了一种新颖的解决方案——LAD-GNN,以显著提升图分类的性能。

论文地址

您可以在AAAI上找到这篇论文的详细内容。

知识蒸馏与GNN-MLP蒸馏

知识蒸馏

定义:一种将复杂模型(教师)的知识迁移到简单模型(学生)的技术。

步骤:训练教师模型。

○ 提取教师模型的知识。

○ 使用这些知识训练学生模型。

师生模型

教师模型:深度、复杂,捕捉复杂特征。

学生模型:简单、易于部署,模仿教师模型。

GNN-MLP 蒸馏

目的:将GNN的知识迁移到MLP,减少模型复杂性。

应用:移动设备。

○ 嵌入式系统。

○ 快速部署。

实施步骤

1. 训练GNN教师模型:在图分类任务上训练。

2. 提取知识:获取节点或图嵌入。

3. 训练MLP学生模型:模仿教师模型的行为。

论文主要创新内容

本文提出了一种新的图神经网络训练方法,称为 LAD-GNN。该方法通过标签注意蒸馏,显著提高了图分类任务的准确性。其主要思路是在训练过程中引入标签信息,通过师生模型架构,实现类友好的节点嵌入表示。

论文创新点

论文的主要创新点在于提出了一种名为标签注意蒸馏方法(LAD-GNN)的新颖方法。该方法通过引入标签注意编码器,将节点特征与标签信息结合在一起,生成更加理想的嵌入表示。标签注意编码器能够捕捉全局图信息,使得节点嵌入更加对齐,从而解决了传统GNN中常见的嵌入不对齐问题。此外,该方法采用了基于师生模型架构的蒸馏学习策略,教师模型通过标签注意编码器生成高质量的嵌入表示,学生模型通过蒸馏学习从教师模型中学习类友好的节点嵌入表示,从而优化图分类任务的性能。实验结果表明,LAD-GNN在多个基准数据集上显著提高了图分类的准确性,展示了其在图神经网络领域的创新性和有效性。

主要贡献

1. 解决嵌入不对齐问题:通过引入标签注意蒸馏方法,有效解决了传统GNN中的嵌入不对齐问题。

2. 提升分类准确性:在多个基准数据集上,LAD-GNN 显著提高了图分类的准确性。

3. 创新性方法:提出了标签注意编码器和蒸馏学习相结合的创新性方法。

模型图

以下是 LAD-GNN 的模型架构图:

LAD-GNN 模型架构

该框架图可以看到该框架分为教师模型和学生模型两个阶段。教师模型的训练过程是通过一种标签关注的训练方法进行的。在这个过程中,标签关注编码器会将真实标签编码成标签嵌入,并将其与由GNN骨干生成的节点嵌入结合,使用注意力机制形成一个理想的嵌入。这个理想嵌入被送入读出函数和分类头,以预测图的标签。标签关注编码器与GNN骨干一起训练,目的是最小化分类损失。

在学生模型的训练阶段,采用了一种基于蒸馏的方法。具体来说,教师模型训练完成后,生成的理想嵌入作为中间监督指导学生模型的训练。学生模型共享教师模型的分类头,通过最小化分类损失和蒸馏损失来继承教师模型的知识,生成有利于图级任务的节点嵌入。

在整个框架中,标签关注编码器起到了关键作用。它由标签编码器和多个注意力机制层组成,通过将标签嵌入和节点嵌入进行特征融合,捕捉两者之间复杂的关系,从而增强模型的表达能力。在实际操作中,标签编码器使用多层感知器(MLP)将标签编码成潜在嵌入,随后通过类似Transformer架构的注意力机制进行处理,形成高级的潜在表示。

技术细节

标签注意蒸馏方法

教师模型使用标签注意编码器,将节点特征与标签信息结合,生成理想的嵌入表示。

学生模型通过蒸馏学习,从教师模型中学习类友好的节点嵌入表示,以优化图分类任务。

方法流程

1. 标签注意教师训练:通过标签注意编码器,将节点特征与标签信息融合,生成理想的嵌入表示,并进行图分类训练。

2. 蒸馏学生学习:学生模型通过蒸馏学习,从教师模型的理想嵌入表示中学习,生成类友好的节点嵌入表示,以提升图分类性能。

算法伪代码

实验结果

论文通过在10个基准数据集上的实验验证了 LAD-GNN 的有效性。结果表明,与现有的最先进GNN方法相比,LAD-GNN 显著提高了图分类的准确性。例如,在 IMDB-BINARY 数据集上,LAD-GNN 使用 GraphSAGE 骨干网实现了高达16.8%的准确性提升。

MUTAG 教师训练

MUTAG 学生训练

这个结果比许多单独使用GNN训练的结果都更好

代码运行

环境配置

这里是所需要的相关python包

  1. - python==3.9.13
  2. - pytorch==1.12.1
  3. - dgl==1.0.2+cu116
  4. - ogb==1.3.5
  5. - sklearn
  6. - torch-cluster==1.6.0
  7. - torch-scatter==2.0.9
  8. - torch-sparse==0.6.15
  9. - torch-geometric==2.1.0
  10. - torchvision==0.13.1

数据集

在主目录下创建一个新的文件夹/data,当您运行代码时,它将自动下载相应的数据集。

运行代码

运行模型很简单,只需要下面两行命令,第一个是先运行教师模型,数据集可以根据数据名称在–dataset MUTAG这里更改,然后还有seed,一般情况下需要使用10个不同的seed进行训练,然后取平均值,数据集不需要自己下载,会自己联网下载,运行过程中请不要使用科技,否则下载会失败。

1. 使用标签注意编码器运行教师模型:

python main.py --dataset MUTAG --train_mode T --device 0 --seed 1 --nhid 64 --nlayers 2 --lr 0.01 --backbone GCN

2. 老师模型训练完成之后使用该命令进行学生模型训练:

python main.py --dataset MUTAG --train_mode S --device 0 --seed 1 --nhid 64 --nlayers 2 --lr 0.001 --backbone GCN

代码目录

  1. LAD-GNN/
  2. ├── Figures/ # 图片目录
  3. │ ├── motivation_fig.jpg # 动机示意图
  4. │ ├── framework.jpg # 整体框架图
  5. │ ├── dataset.jpg # 数据集示意图
  6. │ └── result.jpg # 结果示意图
  7. ├── GNN_models/ # 存放不同的图神经网络模型
  8. │ ├── base_model.py
  9. │ ├── gat.py # 图注意力网络模型
  10. │ ├── gcn.py # 图卷积网络模型
  11. │ ├── gin.py # 图同构网络模型
  12. │ ├── pna.py # 物理网络嵌入模型
  13. │ └── sage.py # 子图聚合增强网络模型
  14. ├── checkpoints/ # 模型检查点目录
  15. │ └── GCN/ # GCN模型的检查点
  16. ├── data/ # 数据集目录
  17. │ └── MUTAG/ # 包含MUTAG数据集的子目录
  18. │ ├── MUTAG
  19. │ ├── processed
  20. │ └── raw
  21. ├── README.md # 项目说明文件
  22. ├── main.py # 主要的Python脚本,用于执行模型训练和测试
  23. ├── test.py # 用于测试模型性能的脚本
  24. ├── requirements.txt # 项目依赖文件
  25. └── utils.py # 包含一些辅助函数的脚本

感觉不错,点击我,立即使用

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/916294
推荐阅读
相关标签
  

闽ICP备14008679号