
Graph Neural Network Study Notes 4: transformer_conv

transformer_conv

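Node prediction task

First we define the structure of the graph neural network. It uses the torch_geometric.nn.Sequential container; see the official documentation for details. The hidden_channels_list parameter sets the out_channels of each convolution layer (GATConv in the course material, GCNConv in the code below), so changing hidden_channels_list yields different graph neural networks. The network structure is varied here in three ways:

  • 1. Using different convolution layers
  • 2. Using different numbers of layers
  • 3. Using a different number of neurons in each layer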
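
As a rough illustration of how hidden_channels_list drives the architecture, the following sketch lists the (input, output) size of every layer for a given setting. The helper name layer_dims is hypothetical, and the feature and class counts in the call are made-up values rather than those of a particular dataset.

```python
def layer_dims(num_features, hidden_channels_list, num_classes):
    """Return the (in, out) sizes of the conv layers and of the final linear layer."""
    hns = [num_features] + hidden_channels_list
    # One convolution layer per entry of hidden_channels_list.
    conv_dims = [(hns[i], hns[i + 1]) for i in range(len(hidden_channels_list))]
    # The linear classifier maps the last hidden size to the number of classes.
    linear_dims = (hidden_channels_list[-1], num_classes)
    return conv_dims, linear_dims


conv_dims, linear_dims = layer_dims(100, [200, 100, 50], 7)
print(conv_dims)    # [(100, 200), (200, 100), (100, 50)]
print(linear_dims)  # (50, 7)
```

Adding an entry to the list adds a layer, and changing an entry changes the width of that layer, which is what the second and third experiments vary.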

1. Using different convolution layers
Five convolution layers are used here: GCNConv, GATConv, SAGEConv, GraphConv and TransformerConv. The code and results are as follows:

# Define the data first
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='dataset', name='Cora', transform=NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('======================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
# has_isolated_nodes() / has_self_loops() replace the deprecated contains_*() methods in PyG 2.x
print(f'Contains isolated nodes: {data.has_isolated_nodes()}')
print(f'Contains self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
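# Define the GCN network structure
import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU
from torch_geometric.nn import GCNConv, Sequential
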
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels_list, num_classes):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        hns = [num_features] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            conv_list.append((GCNConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
            conv_list.append(ReLU(inplace=True),)

        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear(x)
        return x

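The networks for the other convolution layers are built in the same way. The final experimental results are shown in the figure.
[Figure: results for the five convolution layers]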

2. Using different numbers of layers

  • hidden_channels_list=[2]
  • hidden_channels_list=[100]
  • hidden_channels_list=[200, 100]
  • hidden_channels_list=[200, 100, 50]
  • hidden_channels_list=[500, 200, 100]
  • hidden_channels_list=[500, 200, 100, 50]
layers = [[2],
         [100],
         [200, 100],
         [200, 100, 50],
         [500, 200, 100],
         [500, 200, 100, 50]]
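device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
# train() and test() are not shown in the original post; these are an assumed, standard node-classification loop.
def train(model, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

@torch.no_grad()
def test(model):
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=1)
    return int((pred[data.test_mask] == data.y[data.test_mask]).sum()) / int(data.test_mask.sum())

def test_case(layer):
    model = GCN(num_features=dataset.num_features, hidden_channels_list=layer, num_classes=dataset.num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(1, 301):
        loss = train(model, optimizer, criterion)
    test_acc = test(model)
    print(f'Test Accuracy: {test_acc:.4f}')

for layer in layers:
    test_case(layer)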

[Figure: run results for the different layer configurations]

3. Using a different number of neurons in each layer

  • hidden_channels_list=[2]
  • hidden_channels_list=[20]
  • hidden_channels_list=[200]
layers = [[2],
         [20],
         [200]]
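# Relies on a global `device` and on the helpers train(model, optimizer, criterion) and test(model).
# The original post calls them without showing their definitions, so these signatures are an assumption.
def test_case(layer):
    model = GCN(num_features=dataset.num_features, hidden_channels_list=layer, num_classes=dataset.num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(1, 101):
        loss = train(model, optimizer, criterion)
    test_acc = test(model)
    print(f'Test Accuracy: {test_acc:.4f}')

for layer in layers:
    test_case(layer)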

[Figure: run results for the different hidden sizes]
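
To get a feeling for how much the three hidden sizes differ in capacity, the number of trainable parameters can be estimated by hand. The sketch below assumes that every convolution layer holds one weight matrix and one bias vector (which matches GCNConv with default arguments) and uses Cora's 1433 input features and 7 classes; the helper name count_parameters is hypothetical.

```python
def count_parameters(num_features, hidden_channels_list, num_classes):
    """Weights plus biases of the conv stack and the final linear layer."""
    hns = [num_features] + hidden_channels_list
    total = 0
    for fan_in, fan_out in zip(hns[:-1], hns[1:]):
        total += fan_in * fan_out + fan_out  # conv layer: weight matrix + bias
    total += hidden_channels_list[-1] * num_classes + num_classes  # linear layer
    return total


for layer in ([2], [20], [200]):
    print(layer, count_parameters(1433, layer, 7))
# [2] 2889
# [20] 28827
# [200] 288207
```

Almost all parameters sit in the first layer, because the input feature dimension is much larger than any of the hidden sizes.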

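Edge prediction task

The goal of the edge prediction task is to predict whether an edge exists between two nodes. Given a graph dataset, we have the node attributes x and the edge endpoints edge_index; the edges stored in edge_index are the positive samples. To build an edge prediction task we need to generate negative samples, that is, sample node pairs that are not connected by an edge and use them as negative edges. The numbers of positive and negative samples should be balanced. The samples also have to be divided into training, validation and test sets.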
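
To make the idea of negative sampling concrete, here is a minimal pure-Python sketch that draws as many non-edges as there are positive edges. It is only an illustration under simple assumptions (a small undirected graph, rejection sampling, no self-loops); the function name sample_negative_edges is hypothetical, and this is not how PyG implements its own sampler.

```python
import random


def sample_negative_edges(pos_edges, num_nodes, seed=0):
    """Draw as many non-edges as there are distinct positive edges."""
    rng = random.Random(seed)
    # Store each undirected edge with sorted endpoints for easy lookup.
    positives = {tuple(sorted(e)) for e in pos_edges}
    # Never ask for more non-edges than the graph actually has.
    max_pairs = num_nodes * (num_nodes - 1) // 2
    num_wanted = min(len(positives), max_pairs - len(positives))
    negatives = set()
    while len(negatives) < num_wanted:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = (min(u, v), max(u, v))
        if u != v and pair not in positives:
            negatives.add(pair)
    return sorted(negatives)


pos_edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
neg_edges = sample_negative_edges(pos_edges, num_nodes=5)
print(len(neg_edges))  # 4
```

Rejection sampling works well on sparse graphs, where almost every random pair is a non-edge.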

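PyG provides a ready-made function that splits the edges and samples negative edges, train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1), where

  • the first parameter is a torch_geometric.data.Data object,
  • the second parameter is the fraction of edges used for the validation set,
  • the third parameter is the fraction of edges used for the test set.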
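
The effect of the two ratios can be sketched in a few lines of plain Python: shuffle the positive edges, cut off the validation and test fractions, and keep the rest for training. This is a simplified illustration with a hypothetical helper split_edges; the real function also samples negative edges for the validation and test sets and takes care of both directions of every undirected edge, which is omitted here.

```python
import math
import random


def split_edges(edges, val_ratio=0.05, test_ratio=0.1, seed=0):
    """Shuffle the edges and cut them into train, val and test lists."""
    edges = list(edges)
    random.Random(seed).shuffle(edges)
    n_val = math.floor(val_ratio * len(edges))
    n_test = math.floor(test_ratio * len(edges))
    val = edges[:n_val]
    test = edges[n_val:n_val + n_test]
    train = edges[n_val + n_test:]
    return train, val, test


# 200 made-up undirected edges on a ring of 200 nodes.
edges = [(i, (i + 1) % 200) for i in range(200)]
train, val, test = split_edges(edges)
print(len(train), len(val), len(test))  # 170 10 20
```
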
import os.path as osp

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
from torch_geometric.datasets import Planetoid
from torch.nn import Linear, ReLU
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.utils import negative_sampling, train_test_split_edges

###### Implemented with Sequential
# class GAT(torch.nn.Module):
#     def __init__(self, num_features, hidden_channels_list, num_classes):
#         super(GAT, self).__init__()
#         torch.manual_seed(12345)
#         hns = [num_features] + hidden_channels_list
#         conv_list = []
#         for idx in range(len(hidden_channels_list)):
#             conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
#             conv_list.append(ReLU(inplace=True),)

#         self.convseq = Sequential('x, edge_index', conv_list)
#         self.linear = Linear(hidden_channels_list[-1], num_classes)

#     def forward(self, x, edge_index):
#         x = self.convseq(x, edge_index)
#         x = F.dropout(x, p=0.5, training=self.training)
#         x = self.linear(x)
#         return x

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels_list, out_channels):
        super(Net, self).__init__()
        hns = [in_channels] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            conv_list.append((GCNConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
            conv_list.append(ReLU(inplace=True),)
        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = Linear(hidden_channels_list[-1], out_channels)
        
    def encode(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear(x)
        return x

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels


def train(data, model, optimizer):
    model.train()

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))
    
    train_neg_edge_set = set(map(tuple, neg_edge_index.T.tolist()))
    val_pos_edge_set = set(map(tuple, data.val_pos_edge_index.T.tolist()))
    test_pos_edge_set = set(map(tuple, data.test_pos_edge_index.T.tolist()))
    if (len(train_neg_edge_set & val_pos_edge_set) > 0) or (len(train_neg_edge_set & test_pos_edge_set) > 0):
        print('wrong!')

    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test(data, model):
    model.eval()

    z = model.encode(data.x, data.train_pos_edge_index)

    results = []
    for prefix in ['val', 'test']:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return results

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    dataset = 'Cora'
    path = './dataset'
    dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
    data = dataset[0]
    
#     dataset = 'Cora'
#     path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
#     dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
#     data = dataset[0]
    
    
    
    ground_truth_edge_index = data.edge_index.to(device)
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data)
    data = data.to(device)

    model = Net(dataset.num_features, [256, 128], 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

    best_val_auc = test_auc = 0
    for epoch in range(1, 3001):
        loss = train(data, model, optimizer)
        val_auc, tmp_test_auc = test(data, model)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            test_auc = tmp_test_auc
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')

    z = model.encode(data.x, data.train_pos_edge_index)
    final_edge_index = model.decode_all(z)


if __name__ == "__main__":
    main()

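Run results ("wrong!" is printed when the negative edges sampled for training overlap with the positive edges of the validation set or the test set. This can happen because negative_sampling only receives train_pos_edge_index, so the held-out positive edges look like non-edges to it):
[Figure: run output]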
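
One possible way to get rid of this overlap, which the code above does not do, is to filter the sampled training negatives against the held-out positive edges. The sketch below assumes edge tensors of shape [2, num_edges] as used throughout this post; the helper name drop_forbidden_edges and the small tensors are made up for illustration.

```python
import torch


def drop_forbidden_edges(neg_edge_index, forbidden_edge_index):
    """Remove columns of neg_edge_index that also appear in forbidden_edge_index."""
    forbidden = set(map(tuple, forbidden_edge_index.t().tolist()))
    keep = [tuple(e) not in forbidden for e in neg_edge_index.t().tolist()]
    return neg_edge_index[:, torch.tensor(keep, dtype=torch.bool)]


# Made-up example: the sampled negatives contain (0, 2), a held-out positive edge.
neg_edge_index = torch.tensor([[0, 1, 3],
                               [2, 4, 4]])
val_pos_edge_index = torch.tensor([[0, 2],
                                   [2, 3]])
clean = drop_forbidden_edges(neg_edge_index, val_pos_edge_index)
print(clean.tolist())  # [[1, 3], [4, 4]]
```

Filtering leaves slightly fewer negatives than positives in an epoch. Whether the held-out edges may be used this way during training is a design decision in its own right, since it reveals which node pairs are reserved for evaluation.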
This article is mainly based on the Datawhale open-source course.
