当前位置:   article > 正文

论文阅读&STS-HGCN-AL代码详解(二)_hgcn代码

hgcn代码

论文原文:Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction

项目地址:https://github.com/YangLibuaa/STSHGCN-AL

论文地址:Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction | IEEE Journals & Magazine | IEEE Xplore

Github项目中已经附有论文原文,且图片分辨率比IEEE上的更高,推荐直接下载Github项目。

上期内容,本期文章继续结合论文原文,对STS-HGCN-AL的源码进行讲解。

I. GATENet

先放一张原文中GATENet的结构图:

图1  GATENet结构图

GATENet的实现非常简单。对于矩阵A,首先将其展平成一维向量,随后经过两个全连接层即可,这两个全连接层的激活函数分别是ELU和Tanh,随后经过ReLU保留非负元素,再经过Reshape操作恢复原尺寸。代码如下:

  1. class GATENet(nn.Module):
  2. def __init__(self, inc, reduction_ratio = 128):
  3. super(GATENet, self).__init__()
  4. self.fc = nn.Sequential(nn.Linear(inc, inc // reduction_ratio, bias = False),
  5. nn.ELU(inplace = False),
  6. nn.Linear(inc // reduction_ratio, inc, bias = False),
  7. nn.Tanh(),
  8. nn.ReLU(inplace = False))
  9. def forward(self, x):
  10. y = self.fc(x)
  11. return y
  1. x = self.ST_SENet(x)
  2. A_ds = self.GATENet(self.A)
  3. A_ds = A_ds.reshape(self.chan_num, self.chan_num)

代码也是非常的简洁易懂。首先使用ST_SENet生成每个频段的ICs响应,其维度为128*320*1*19。第二行和第三行首先计算A_ds,也就是频内图的邻接矩阵。

唯一需要注意的是reduction_ratio这个参数,它的作用是通过一个Linear层,将展平后的A映射到一个低维空间,再通过另一个Linear层复原到原尺寸,类似编码器-解码器的结构。这里inc=19*19,所以中间层只有19*19//128=2个神经元。

II. HGCN

图2  HGCN结构

接着往下看代码,使用torch.einsum函数计算Einstein约定求和:

L = torch.einsum('ik,kp->ip', (A_ds, torch.diag(torch.reciprocal(sum(A_ds)))))

这个式子我们逐条分析:

sum()是按列求和函数,

torch.reciprocal()是对输入的Tensor取倒数,

torch.diag()当输入为向量时,返回以该向量中的元素为主对角线的对角阵,

torch.einsum()计算Einstein约定求和,例如计算'ik,kp->ip',可表示为:

C_p^i = A_k^iB_p^k

也就是计算矩阵A和B的乘积。

这样分析下来,此行代码与下列代码等价:

L = torch.mm(A_ds, torch.diag(torch.reciprocal(sum(A_ds))))

感兴趣的读者可自行验证,二者计算结果相同。

所以这里是计算D^{-1} A_{d s}(此处存疑,按照代码的计算方式,这是在计算A_{d s}D^{-1},暂且认为是代码错误)

根据图2,HGCN的结构,发现还存在rhythmAtt-A和rhythmAtt-G两个模块。代码首先计算了rhythmAtt-A。

1. rhythmAtt-A

图3  rhythmAtt-A结构
  1. m = self.rhythmAtt_A(torch.cat((self.m1.weight,
  2. self.m2.weight,
  3. self.m3.weight,
  4. self.m4.weight,
  5. self.m5.weight), 0))

如图3所示,rhythmAtt-A以一组池化向量M为输入,对应于m1~m5,它们的作用是将输入按通道进行池化,即输出尺寸为19*1的向量。查看m1~m5的实现:

  1. self.m1 = nn.Linear(chan_num, 1)
  2. self.m2 = nn.Linear(chan_num, 1)
  3. self.m3 = nn.Linear(chan_num, 1)
  4. self.m4 = nn.Linear(chan_num, 1)
  5. self.m5 = nn.Linear(chan_num, 1)

很显然,这就是一组线性层,因此m1~m5的weight尺寸为1*19,将其拼接起来的尺寸就是5*19。

进入rhythmAtt_A的具体实现:

  1. class rhythmAtt_A(nn.Module):
  2. def __init__(self, inc, scaling_factor = 16):
  3. super(rhythmAtt_A, self).__init__()
  4. self.bn = nn.BatchNorm1d(5)
  5. self.scaling_factor = scaling_factor
  6. self.projection1 = nn.Linear(inc, scaling_factor, bias = False)
  7. self.projection2 = nn.Linear(inc, scaling_factor, bias = False)
  8. def forward(self, x):
  9. x_norm = self.bn(x.transpose(1, 0)).transpose(1, 0)
  10. x1, x2 = self.projection1(x_norm), self.projection2(x_norm)
  11. L = torch.einsum('ij,jk->ik', (x1, x2.transpose(1,0)))/np.sqrt(self.scaling_factor)
  12. L = torch.softmax(L, -1)
  13. return L

按照原文,m1~m5首先需要进行归一化处理。代码很巧妙地运用了BatchNorm1d操作,对矩阵m进行归一化操作。由于m的尺寸是5*19,而BatchNorm1d是沿着第1维进行归一化,所以需要先进行一下转置,归一化结束后再转置回来,这就是forward第一行的含义。

再来看一下原文中rhythmAtt-A的公式:

L=\operatorname{softmax}\left(\frac{1}{\sqrt{d}}\left(\theta_1 \times M\right)^T \times\left(\theta_2 \times M\right)\right)

其中M是拼接起来的归一化后的m1~m5。那么到这里就一目了然了,M经过两个投影矩阵后,经过矩阵乘法操作、放缩、softmax,起到了类似注意力机制的作用。最终得到的L尺寸是5*5,第(i, j)个元素表示第j个频段对第i个频段的影响程度,大家可以自行验证m.shape。

2. rhythmAtt-G

首先是rhythmAtt-G的结构:

图4  rhythmAtt-G结构

rhythmAtt-G以rhythmAtt-A得到的L矩阵和ST-SENet的ICs为输入。

继续看forward部分代码:

imp = m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1).unsqueeze(3).expand_as(torch.zeros(5, 5*self.dim, 1, self.chan_num))

非常长,但不要慌,我们从头分析各个变量的尺寸:

  1. m: 5*5
  2. m.unsqueeze(2): 5*5*1
  3. m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous(): 5*5*64
  4. m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1): 5*320*1
  5. m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1).unsqueeze(3): 5*320*1*1
  6. m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1).unsqueeze(3).expand_as(torch.zeros(5, 5*self.dim, 1, self.chan_num)): 5*320*1*19 

也就是说imp的大小是5*320*1*19,看到这里我们会有一个疑问,这是在计算什么?其实这是为了计算原文的公式(5)做准备:

\tilde{H}^{(i, k)}=\delta\left(\sum_{j=1}^5 L_{i, j} H^{(j, k)}+H^{(i, k)} \theta^{(i, k)}\right)

上边计算得到的L,其实只是一个频段的L矩阵。而矩阵H的维度为128*320*1*19,因此要计算L与H的逐元素乘积,就需要将L扩展成5*320*1*19的矩阵。

先看上式第二项,\theta^{(i, k)}是rhythmAtt-G的权重,通过一个Linear层实现,但本文仍然采用1*1卷积实现:

  1. self.rhythmAtt_G1 = nn.Conv2d(in_channels = dim*5,
  2. out_channels = dim*5,
  3. kernel_size = (1, 1),
  4. stride = 1,
  5. padding = (0, 0),
  6. groups = 5,
  7. bias = False)

卷积核大小为1*1,输入和输出通道数均为320,这就保证了H^{(i, k)} \theta^{(i, k)}的大小与H保持一致。

再看上式第一项,计算L和H的逐元素乘积,其代码实现:

x = self.ELU(self.bn1(sum(torch.einsum('pijk,bijk->bpijk', (imp, x)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G1(x)))

又是一个非常长的式子,我们从里至外以此解读:

temp_a = torch.einsum('pijk,bijk->bpijk', (imp, x))

由于imp的尺寸为5*320*1*19,而x的尺寸为128*320*1*19,因此temp_a的尺寸应为128*5*320*1*19。

temp_b = temp_a.split(self.dim, 2))

split函数沿着指定的维度,将数据切分成若干块,每个块的大小由第一个参数指定。

这里dim=64,因此得到的temp_b是一个tuple,该tuple中每个元素的尺寸为128*5*64*1*19,共有5个元素。

temp_c = sum(temp_b).view(s1, s2, s3, s4)

sum(temp_b)将tuple中所有元素相加,因此得到的尺寸为128*5*64*1*19,而view函数相当于reshape,(s1, s2, s3, s4)是矩阵H的维度,即128,320,1,19,所以temp_c的尺寸是128*320*1*19。

x = self.ELU(self.bn1(temp_c+self.rhythmAtt_G1(x)))

经过从里至外的分析,这行代码就十分清晰了:temp_c相当于论文中公式(5)的第一项,self.rhythmAtt_G1(x)相当于公式(5)的第二项,经过BN层和ELU激活函数,与(5)相一致。

3. resGCN

了解了rhythmAtt-A和rhythmAtt-G之后,剩下的就是HGCN的第三个组件——resGCN了。

图5  resGCN结构

先放上论文中resGCN的公式:

H^{(i, k+1)}=\delta_2\left(D^{-1} A_{d s} \delta_1\left(H^{(i, k)} W_1^{(i, k)}\right) W_2^{(i, k)}+\sum_{j=0}^{k-1} H^{(i, j)}\right)

先前我们已经计算了D^{-1} A_{d s},结合代码:

  1. class resGCN(nn.Module):
  2. def __init__(self, inc, outc):
  3. super(resGCN, self).__init__()
  4. self.GConv1 = nn.Conv2d(in_channels = inc,
  5. out_channels = outc,
  6. kernel_size = (1, 1),
  7. stride = (1, 1),
  8. padding = (0, 0),
  9. groups = 5,
  10. bias = False)
  11. self.bn1 = nn.BatchNorm2d(outc)
  12. self.GConv2 = nn.Conv2d(in_channels = outc,
  13. out_channels = outc,
  14. kernel_size = (1, 1),
  15. stride = (1, 1),
  16. padding = (0, 0),
  17. groups = 5,
  18. bias = False)
  19. self.bn2 = nn.BatchNorm2d(outc)
  20. self.ELU = nn.ELU(inplace = False)
  21. self.initialize()
  22. def initialize(self):
  23. for m in self.modules():
  24. if isinstance(m, nn.Conv2d):
  25. nn.init.xavier_uniform_(m.weight, gain = 1)
  26. elif isinstance(m, nn.BatchNorm2d):
  27. nn.init.constant_(m.weight, 1)
  28. nn.init.constant_(m.bias, 0)
  29. def forward(self, x, x_p, L):
  30. x = self.bn2(self.GConv2(self.ELU(self.bn1(self.GConv1(x)))))
  31. y = torch.einsum('bijk,kp->bijp', (x, L))
  32. y = self.ELU(torch.add(y, x_p))
  33. return y

参考STS_HGCN的init部分:

  1. self.resGCN1 = resGCN(inc = dim*5,
  2. outc = dim*5)
  3. self.resGCN2 = resGCN(inc = dim*5,
  4. outc = dim*5)
  5. self.resGCN3 = resGCN(inc = dim*5,
  6. outc = dim*5)
  7. self.resGCN4 = resGCN(inc = dim*5,
  8. outc = dim*5)

可见resGCN的init部分传入的inc和outc都是64*5=320。resGCN由两个1*1卷积层组成,但group=5,这决定了输入卷积通道与输出卷积通道被分成5组,仅组内可见。如图6所示,分成5组是因为每组都对应于一个频段的intrarhythm graph。

图6  层级图与resGCN的关系

resGCN的forward部分是严格按照图5进行的。

在forward的第一行计算temp_d = \delta_1 (H^{(i, k)} W_1^{(i, k)}) W_2^{(i, k)},第二行计算temp_e = D^{-1} A_{d s} temp_d。由于temp_d的尺寸为128*320*1*19,故temp_e的尺寸也为128*320*1*19。注意,这两行实际上执行了图卷积的前向传播。其中H是节点特征,A是节点邻接矩阵。

forward的第三行执行残差块连接,将当前层的计算结果temp_e与之前所有层的输出进行累加,最后使用ELU激活函数,得到当前层的输出。

参考STS_HGCN的forward部分,可知G1,G2,G3,G4代表4层resGCN,而resGCN的第二个输入参数则是之前所有层的输出之和,与文章所述内容一致:

  1. x = self.ELU(self.bn1(sum(torch.einsum('pijk,bijk->bpijk', (imp, x)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G1(x)))
  2. G1 = self.resGCN1(x, x, L).contiguous()
  3. G1 = self.ELU(self.bn2(sum(torch.einsum('pijk,bijk->bpijk', (imp, G1)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G2(G1)))
  4. G2 = self.resGCN2(G1, torch.add(x, G1), L).contiguous()
  5. G2 = self.ELU(self.bn3(sum(torch.einsum('pijk,bijk->bpijk', (imp, G2)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G3(G2)))
  6. G3 = self.resGCN3(G2, torch.add(torch.add(x, G1), G2), L).contiguous()
  7. G3 = self.ELU(self.bn4(sum(torch.einsum('pijk,bijk->bpijk', (imp, G3)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G4(G3)))
  8. G4 = self.resGCN4(G3, torch.add(torch.add(torch.add(x, G1), G2), G3), L).contiguous()
  9. G4 = self.ELU(self.bn5(sum(torch.einsum('pijk,bijk->bpijk', (imp, G4)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G5(G4)))

按照图2所示,经过5层rhythmAtt-G和4层resGCN后,五个频段的计算结果被单独分开:

A, B, C, D, E = G4.split(self.dim, 1)

由于G4的尺寸仍为128*320*1*19,因此得到的A,B,C,D,E各为128*64*1*19大小。同时由于组卷积的原因,各个频段之间并未混淆。

4. 节点池化

随后是节点池化。A,B,C,D,E即为各个频段的节点特征,对其沿着EEG通道维池化,即期望输出大小应为128*64*1*1。下面看forward部分的具体操作:

  1. y = torch.cat((self.m1(A.view(A.size(0), A.size(1), -1)).unsqueeze(-1).contiguous(),
  2. self.m2(B.view(B.size(0), B.size(1), -1)).unsqueeze(-1).contiguous(),
  3. self.m3(C.view(C.size(0), C.size(1), -1)).unsqueeze(-1).contiguous(),
  4. self.m4(D.view(D.size(0), D.size(1), -1)).unsqueeze(-1).contiguous(),
  5. self.m5(E.view(E.size(0), E.size(1), -1)).unsqueeze(-1).contiguous()), 1)

A.view(A.size(0), A.size(1), -1)这一步是将A变成128*64*19的尺寸,使其可以与1*19的池化向量相乘,所以self.m1(A.view(A.size(0), A.size(1), -1))的尺寸是128*64*1,对其进行unsqueeze(-1)操作,尺寸变成128*64*1*1,达到期望输出。最后将五个池化层输出沿第1维串联,得到128*320*1*1的输出。

5. 分类层

图7  分类层结构

最后将池化输出送入分类器进行分类。分类器的定义为:

  1. class classification_net(nn.Module):
  2. def __init__(self, inc, tmp, outc):
  3. super(classification_net, self).__init__()
  4. self.conv1 = nn.Conv2d(in_channels = inc*5,
  5. out_channels = tmp,
  6. kernel_size = (1, 1),
  7. stride = 1,
  8. padding = (0, 0),
  9. bias = True)
  10. self.conv2 = nn.Conv2d(in_channels = tmp,
  11. out_channels = outc,
  12. kernel_size = (1, 1),
  13. stride = 1,
  14. padding = (0, 0),
  15. bias = True)
  16. self.initialize()
  17. def initialize(self):
  18. for m in self.modules():
  19. if isinstance(m, nn.Conv2d):
  20. nn.init.xavier_uniform_(m.weight, gain = 1)
  21. elif isinstance(m, nn.BatchNorm2d):
  22. nn.init.constant_(m.weight, 1)
  23. nn.init.constant_(m.bias, 0)
  24. def forward(self, x):
  25. y = self.conv2(self.conv1(x))
  26. return y

按照图7所示,分类器由2个线性层构成,这里使用1*1卷积层代替线性层。第1个线性层的输入维度为64*5=320,输出维度为32,第2个线性层的输入维度为32,输出维度为2,实现预测。其实个人认为应该在第2个线性层后面加上softmax激活函数,就完美了,不知作者是否忽略了这一点。

由于输出y的尺寸为128*2*1*1,使用squeeze()值为1的维度进行压缩,最终得到128*2的预测标签:

pred = self.classification_net(y).squeeze()

至此,关于STS-HGCN的部分已经全部介绍完毕。关于论文中的另一个重要部分——主动学习(AL)则没有对应的代码,因此也无法进一步深挖了。但无论是ST-SENet,还是STS-HGCN,在设计思路上都有很多可借鉴之处,同时通过阅读代码,加深了对各个组件底层实现的理解。

论文常读常新,每次阅读都会有不一样的感觉,希望大家能融会贯通,举一反三,在此基础上设计出更优秀的模型。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/833750
推荐阅读
相关标签
  

闽ICP备14008679号