
[Tutorial] Explaining GNN Model Predictions with Captum

Please credit the source when reposting: 小锋学长生活大爆炸 [xfxuezhang.cn]

Colab Notebook

Install the required libraries:

    # Install required packages.
    import os
    import torch
    os.environ['TORCH'] = torch.__version__
    print(torch.__version__)
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
    !pip install -q captum

    # Helper function for visualization.
    %matplotlib inline
    import matplotlib.pyplot as plt
    1.13.1+cu116
    Installing build dependencies ... done
    Getting requirements to build wheel ... done
    Preparing metadata (pyproject.toml) ... done

Explaining GNN Model Predictions with Captum

In this tutorial, we demonstrate how to apply feature attribution methods to graphs. Specifically, we try to find the edges that are most important for each instance's prediction.

We use the Mutagenicity dataset from TUDatasets. It consists of 4337 molecular graphs, and the task is to predict the mutagenicity of each molecule.

Loading the Dataset

We load the dataset and use 10% of the data as the test split.

    from torch_geometric.loader import DataLoader
    from torch_geometric.datasets import TUDataset

    path = '.'
    dataset = TUDataset(path, name='Mutagenicity').shuffle()
    test_dataset = dataset[:len(dataset) // 10]
    train_dataset = dataset[len(dataset) // 10:]
    test_loader = DataLoader(test_dataset, batch_size=128)
    train_loader = DataLoader(train_dataset, batch_size=128)
    Downloading https://www.chrsmrrs.com/graphkerneldatasets/Mutagenicity.zip
    Extracting ./Mutagenicity/Mutagenicity.zip
    Processing...
    Done!
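As a quick sanity check (a minimal sketch, not part of the original post; the counts in the comment assume the standard Mutagenicity release), you can print the basic dataset statistics:

    print(len(dataset), dataset.num_features, dataset.num_classes)
    # Expected roughly: 4337 graphs, 14 node features (one-hot atom types), 2 classes.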

Visualizing the Data

We define a few utility functions for visualizing molecules, and then sample a molecule at random.

    import networkx as nx
    import numpy as np
    from torch_geometric.utils import to_networkx


    def draw_molecule(g, edge_mask=None, draw_edge_labels=False):
        g = g.copy().to_undirected()
        node_labels = {}
        for u, data in g.nodes(data=True):
            node_labels[u] = data['name']
        pos = nx.planar_layout(g)
        pos = nx.spring_layout(g, pos=pos)
        if edge_mask is None:
            edge_color = 'black'
            widths = None
        else:
            edge_color = [edge_mask[(u, v)] for u, v in g.edges()]
            widths = [x * 10 for x in edge_color]
        nx.draw(g, pos=pos, labels=node_labels, width=widths,
                edge_color=edge_color, edge_cmap=plt.cm.Blues,
                node_color='azure')
        if draw_edge_labels and edge_mask is not None:
            edge_labels = {k: ('%.2f' % v) for k, v in edge_mask.items()}
            nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels,
                                         font_color='red')
        plt.show()


    def to_molecule(data):
        ATOM_MAP = ['C', 'O', 'Cl', 'H', 'N', 'F',
                    'Br', 'S', 'P', 'I', 'Na', 'K', 'Li', 'Ca']
        g = to_networkx(data, node_attrs=['x'])
        for u, data in g.nodes(data=True):
            data['name'] = ATOM_MAP[data['x'].index(1.0)]
            del data['x']
        return g

Visualizing a Sample

We sample a single molecule from train_dataset and visualize it.

    import random

    data = random.choice([t for t in train_dataset])
    mol = to_molecule(data)
    plt.figure(figsize=(10, 5))
    draw_molecule(mol)

Training the Model

In the next section, we train a GNN model with 5 convolutional layers. We use GraphConv, which accepts edge_weight as an argument; many of PyTorch Geometric's convolutional layers support this argument.
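As a minimal sketch of what edge_weight does (an addition, not from the original tutorial), you can check that passing an all-ones weight vector reproduces the unweighted output of a GraphConv layer:

    import torch
    from torch_geometric.nn import GraphConv

    conv = GraphConv(4, 8)
    x = torch.randn(3, 4)                              # 3 nodes, 4 features
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # 3 directed edges
    ones = torch.ones(edge_index.shape[1])
    # Omitting edge_weight is equivalent to weighting every edge by 1.
    assert torch.allclose(conv(x, edge_index), conv(x, edge_index, ones))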

Defining the Model

    import torch
    from torch.nn import Linear
    import torch.nn.functional as F
    from torch_geometric.nn import GraphConv, global_add_pool


    class Net(torch.nn.Module):
        def __init__(self, dim):
            super(Net, self).__init__()
            num_features = dataset.num_features
            self.dim = dim
            self.conv1 = GraphConv(num_features, dim)
            self.conv2 = GraphConv(dim, dim)
            self.conv3 = GraphConv(dim, dim)
            self.conv4 = GraphConv(dim, dim)
            self.conv5 = GraphConv(dim, dim)
            self.lin1 = Linear(dim, dim)
            self.lin2 = Linear(dim, dataset.num_classes)

        def forward(self, x, edge_index, batch, edge_weight=None):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            x = self.conv3(x, edge_index, edge_weight).relu()
            x = self.conv4(x, edge_index, edge_weight).relu()
            x = self.conv5(x, edge_index, edge_weight).relu()
            x = global_add_pool(x, batch)
            x = self.lin1(x).relu()
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.lin2(x)
            return F.log_softmax(x, dim=-1)

Defining the Training and Test Functions

    def train(epoch):
        model.train()
        if epoch == 51:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.5 * param_group['lr']
        loss_all = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data.x, data.edge_index, data.batch)
            loss = F.nll_loss(output, data.y)
            loss.backward()
            loss_all += loss.item() * data.num_graphs
            optimizer.step()
        return loss_all / len(train_dataset)


    def test(loader):
        model.eval()
        correct = 0
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.batch)
            pred = output.max(dim=1)[1]
            correct += pred.eq(data.y).sum().item()
        return correct / len(loader.dataset)
Run

We train the model for 100 epochs; the final accuracy should be around 80%.

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(dim=32).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, 101):
        loss = train(epoch)
        train_acc = test(train_loader)
        test_acc = test(test_loader)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
    Epoch: 090, Loss: 0.2992, Train Acc: 0.8824, Test Acc: 0.7968
    Epoch: 091, Loss: 0.3000, Train Acc: 0.8855, Test Acc: 0.8060
    Epoch: 092, Loss: 0.3129, Train Acc: 0.8832, Test Acc: 0.8037
    Epoch: 093, Loss: 0.3056, Train Acc: 0.8791, Test Acc: 0.8129
    Epoch: 094, Loss: 0.2947, Train Acc: 0.8835, Test Acc: 0.8014
    Epoch: 095, Loss: 0.2949, Train Acc: 0.8758, Test Acc: 0.8129
    Epoch: 096, Loss: 0.2946, Train Acc: 0.8791, Test Acc: 0.8060
    Epoch: 097, Loss: 0.2989, Train Acc: 0.8768, Test Acc: 0.8083
    Epoch: 098, Loss: 0.2946, Train Acc: 0.8822, Test Acc: 0.7968
    Epoch: 099, Loss: 0.2908, Train Acc: 0.8835, Test Acc: 0.8060
    Epoch: 100, Loss: 0.2910, Train Acc: 0.8840, Test Acc: 0.8037

Explaining the Predictions

We now look at two popular attribution methods. First, we compute the gradient of the output with respect to the edge weights $w_e$, which are initially set to one for all edges. For the saliency method, we use the absolute value of the gradient as the attribution value of each edge:

$$\mathrm{Saliency}(x) = \left|\frac{\partial F(x)}{\partial x}\right|$$

where $x$ is the input and $F(x)$ is the output of the GNN model on input $x$.

For the integrated gradients method, we interpolate between the current input and a baseline input in which all edge weights are zero, and accumulate the gradient value of each edge:

$$\mathrm{IG}(x) = \int_{\alpha=0}^{1} \frac{\partial F(x_\alpha)}{\partial x}\,d\alpha$$

where $x_\alpha$ is identical to the original input graph except that all edge weights are set to $\alpha$. The full formulation of integrated gradients is more involved, but because our initial edge weights equal 1 and the baseline is 0, it simplifies to the expression above. You can read more about the method here. Of course, this integral cannot be computed directly; it is approximated by a discrete sum.
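As an illustrative sketch (not part of the original post), the discrete-sum approximation can be written out by hand, reusing the model and the sampled data object from the cells above; steps is a hypothetical step count:

    # Manual Riemann-sum approximation of integrated gradients over edge weights.
    steps = 50  # hypothetical number of interpolation steps
    model.eval()
    data = data.to(device)
    batch = torch.zeros(data.x.shape[0], dtype=torch.long, device=device)
    total_grad = torch.zeros(data.edge_index.shape[1], device=device)
    for alpha in torch.linspace(0, 1, steps):
        # Edge weights interpolated between the all-zeros baseline and all-ones input.
        w = (alpha * torch.ones(data.edge_index.shape[1], device=device)).requires_grad_(True)
        out = model(data.x, data.edge_index, batch, w)
        out[0, 0].backward()  # gradient of the class-0 log-probability
        total_grad += w.grad
    # (input - baseline) = 1 for every edge, so IG is just the averaged gradient.
    ig_approx = total_grad / steps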

We use the captum library to compute the attribution values. We define a model_forward function that constructs the batch argument, under the assumption that we explain only one graph at a time.

    from captum.attr import Saliency, IntegratedGradients


    def model_forward(edge_mask, data):
        batch = torch.zeros(data.x.shape[0], dtype=int).to(device)
        out = model(data.x, data.edge_index, batch, edge_mask)
        return out


    def explain(method, data, target=0):
        input_mask = torch.ones(data.edge_index.shape[1]).requires_grad_(True).to(device)
        if method == 'ig':
            ig = IntegratedGradients(model_forward)
            mask = ig.attribute(input_mask, target=target,
                                additional_forward_args=(data,),
                                internal_batch_size=data.edge_index.shape[1])
        elif method == 'saliency':
            saliency = Saliency(model_forward)
            mask = saliency.attribute(input_mask, target=target,
                                      additional_forward_args=(data,))
        else:
            raise Exception('Unknown explanation method')

        edge_mask = np.abs(mask.cpu().detach().numpy())
        if edge_mask.max() > 0:  # avoid division by zero
            edge_mask = edge_mask / edge_mask.max()
        return edge_mask

Finally, we take a random sample from the test dataset and run the explanation methods. To simplify the visualization, we make the graph undirected and merge each edge's attributions from both directions.

It is well known that in many cases the NO2 substructure makes a molecule mutagenic, and you can verify this through the model's explanations.

In this dataset, mutagenic molecules carry the label 0, and we only sample from those, but you can modify the code to see explanations for the other class as well.

In the visualization, the color and thickness of an edge represent its importance. You can also inspect the numeric values by passing draw_edge_labels to the draw_molecule function (see the one-line sketch after the code below).

As you can see, integrated gradients tend to produce more accurate explanations.

    import random
    from collections import defaultdict


    def aggregate_edge_directions(edge_mask, data):
        edge_mask_dict = defaultdict(float)
        for val, u, v in list(zip(edge_mask, *data.edge_index)):
            u, v = u.item(), v.item()
            if u > v:
                u, v = v, u
            edge_mask_dict[(u, v)] += val
        return edge_mask_dict


    data = random.choice([t for t in test_dataset if not t.y.item()])
    mol = to_molecule(data)

    for title, method in [('Integrated Gradients', 'ig'), ('Saliency', 'saliency')]:
        edge_mask = explain(method, data, target=0)
        edge_mask_dict = aggregate_edge_directions(edge_mask, data)
        plt.figure(figsize=(10, 5))
        plt.title(title)
        draw_molecule(mol, edge_mask_dict)
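For example, to draw the numeric attribution values on the edges (reusing mol and edge_mask_dict from the last loop iteration above):

    plt.figure(figsize=(10, 5))
    draw_molecule(mol, edge_mask_dict, draw_edge_labels=True)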
