赞
踩
论文原文:Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction
项目地址:https://github.com/YangLibuaa/STSHGCN-AL
Github项目中已经附有论文原文,且图片分辨率比IEEE上的更高,推荐直接下载Github项目。
接上期内容,本期文章继续结合论文原文,对STS-HGCN-AL的源码进行讲解。
先放一张原文中GATENet的结构图:
GATENet的实现非常简单。对于矩阵A,首先将其展平成一维向量,随后经过两个全连接层即可,这两个全连接层的激活函数分别是ELU和Tanh,随后经过ReLU保留非负元素,再经过Reshape操作恢复原尺寸。代码如下:
- class GATENet(nn.Module):
- def __init__(self, inc, reduction_ratio = 128):
- super(GATENet, self).__init__()
- self.fc = nn.Sequential(nn.Linear(inc, inc // reduction_ratio, bias = False),
- nn.ELU(inplace = False),
- nn.Linear(inc // reduction_ratio, inc, bias = False),
- nn.Tanh(),
- nn.ReLU(inplace = False))
-
- def forward(self, x):
- y = self.fc(x)
- return y
- x = self.ST_SENet(x)
- A_ds = self.GATENet(self.A)
- A_ds = A_ds.reshape(self.chan_num, self.chan_num)
代码也是非常的简洁易懂。首先使用ST_SENet生成每个频段的ICs响应,其维度为128*320*1*19。第二行和第三行首先计算A_ds,也就是频内图的邻接矩阵。
唯一需要注意的是reduction_ratio这个参数,它的作用是通过一个Linear层,将展平后的A映射到一个低维空间,再通过另一个Linear层复原到原尺寸,类似编码器-解码器的结构。这里inc=19*19,所以中间层只有19*19//128=2个神经元。
接着往下看代码,使用torch.einsum函数计算Einstein约定求和:
L = torch.einsum('ik,kp->ip', (A_ds, torch.diag(torch.reciprocal(sum(A_ds)))))
这个式子我们逐条分析:
sum()是按列求和函数,
torch.reciprocal()是对输入的Tensor取倒数,
torch.diag()当输入为向量时,返回以该向量中的元素为主对角线的对角阵,
torch.einsum()计算Einstein约定求和,例如计算'ik,kp->ip',可表示为:
也就是计算矩阵A和B的乘积。
这样分析下来,此行代码与下列代码等价:
L = torch.mm(A_ds, torch.diag(torch.reciprocal(sum(A_ds))))
感兴趣的读者可自行验证,二者计算结果相同。
所以这里是计算(此处存疑,按照代码的计算方式,这是在计算
,暂且认为是代码错误)
根据图2,HGCN的结构,发现还存在rhythmAtt-A和rhythmAtt-G两个模块。代码首先计算了rhythmAtt-A。
- m = self.rhythmAtt_A(torch.cat((self.m1.weight,
- self.m2.weight,
- self.m3.weight,
- self.m4.weight,
- self.m5.weight), 0))
如图3所示,rhythmAtt-A以一组池化向量M为输入,对应于m1~m5,它们的作用是将输入按通道进行池化,即输出尺寸为19*1的向量。查看m1~m5的实现:
- self.m1 = nn.Linear(chan_num, 1)
- self.m2 = nn.Linear(chan_num, 1)
- self.m3 = nn.Linear(chan_num, 1)
- self.m4 = nn.Linear(chan_num, 1)
- self.m5 = nn.Linear(chan_num, 1)
很显然,这就是一组线性层,因此m1~m5的weight尺寸为1*19,将其拼接起来的尺寸就是5*19。
进入rhythmAtt_A的具体实现:
- class rhythmAtt_A(nn.Module):
- def __init__(self, inc, scaling_factor = 16):
- super(rhythmAtt_A, self).__init__()
- self.bn = nn.BatchNorm1d(5)
- self.scaling_factor = scaling_factor
- self.projection1 = nn.Linear(inc, scaling_factor, bias = False)
- self.projection2 = nn.Linear(inc, scaling_factor, bias = False)
-
- def forward(self, x):
- x_norm = self.bn(x.transpose(1, 0)).transpose(1, 0)
- x1, x2 = self.projection1(x_norm), self.projection2(x_norm)
- L = torch.einsum('ij,jk->ik', (x1, x2.transpose(1,0)))/np.sqrt(self.scaling_factor)
- L = torch.softmax(L, -1)
- return L
按照原文,m1~m5首先需要进行归一化处理。代码很巧妙地运用了操作,对矩阵m进行归一化操作。由于m的尺寸是5*19,而
是沿着第1维进行归一化,所以需要先进行一下转置,归一化结束后再转置回来,这就是forward第一行的含义。
再来看一下原文中rhythmAtt-A的公式:
其中M是拼接起来的归一化后的m1~m5。那么到这里就一目了然了,M经过两个投影矩阵后,经过矩阵乘法操作、放缩、softmax,起到了类似注意力机制的作用。最终得到的L尺寸是5*5,第(i, j)个元素表示第j个频段对第i个频段的影响程度,大家可以自行验证m.shape。
首先是rhythmAtt-G的结构:
rhythmAtt-G以rhythmAtt-A得到的L矩阵和ST-SENet的ICs为输入。
继续看forward部分代码:
imp = m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1).unsqueeze(3).expand_as(torch.zeros(5, 5*self.dim, 1, self.chan_num))
非常长,但不要慌,我们从头分析各个变量的尺寸:
- m: 5*5
- m.unsqueeze(2): 5*5*1
- m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous(): 5*5*64
- m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1): 5*320*1
- m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1).unsqueeze(3): 5*320*1*1
- m.unsqueeze(2).expand_as(torch.zeros(5, 5, self.dim)).contiguous().view(5, -1, 1).unsqueeze(3).expand_as(torch.zeros(5, 5*self.dim, 1, self.chan_num)): 5*320*1*19
也就是说imp的大小是5*320*1*19,看到这里我们会有一个疑问,这是在计算什么?其实这是为了计算原文的公式(5)做准备:
上边计算得到的L,其实只是一个频段的L矩阵。而矩阵H的维度为128*320*1*19,因此要计算L与H的逐元素乘积,就需要将L扩展成5*320*1*19的矩阵。
先看上式第二项,是rhythmAtt-G的权重,通过一个Linear层实现,但本文仍然采用1*1卷积实现:
- self.rhythmAtt_G1 = nn.Conv2d(in_channels = dim*5,
- out_channels = dim*5,
- kernel_size = (1, 1),
- stride = 1,
- padding = (0, 0),
- groups = 5,
- bias = False)
卷积核大小为1*1,输入和输出通道数均为320,这就保证了的大小与H保持一致。
再看上式第一项,计算L和H的逐元素乘积,其代码实现:
x = self.ELU(self.bn1(sum(torch.einsum('pijk,bijk->bpijk', (imp, x)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G1(x)))
又是一个非常长的式子,我们从里至外以此解读:
temp_a = torch.einsum('pijk,bijk->bpijk', (imp, x))
由于imp的尺寸为5*320*1*19,而x的尺寸为128*320*1*19,因此temp_a的尺寸应为128*5*320*1*19。
temp_b = temp_a.split(self.dim, 2))
split函数沿着指定的维度,将数据切分成若干块,每个块的大小由第一个参数指定。
这里dim=64,因此得到的temp_b是一个tuple,该tuple中每个元素的尺寸为128*5*64*1*19,共有5个元素。
temp_c = sum(temp_b).view(s1, s2, s3, s4)
sum(temp_b)将tuple中所有元素相加,因此得到的尺寸为128*5*64*1*19,而view函数相当于reshape,(s1, s2, s3, s4)是矩阵H的维度,即128,320,1,19,所以temp_c的尺寸是128*320*1*19。
x = self.ELU(self.bn1(temp_c+self.rhythmAtt_G1(x)))
经过从里至外的分析,这行代码就十分清晰了:temp_c相当于论文中公式(5)的第一项,self.rhythmAtt_G1(x)相当于公式(5)的第二项,经过BN层和ELU激活函数,与(5)相一致。
了解了rhythmAtt-A和rhythmAtt-G之后,剩下的就是HGCN的第三个组件——resGCN了。
先放上论文中resGCN的公式:
先前我们已经计算了,结合代码:
- class resGCN(nn.Module):
- def __init__(self, inc, outc):
- super(resGCN, self).__init__()
- self.GConv1 = nn.Conv2d(in_channels = inc,
- out_channels = outc,
- kernel_size = (1, 1),
- stride = (1, 1),
- padding = (0, 0),
- groups = 5,
- bias = False)
- self.bn1 = nn.BatchNorm2d(outc)
- self.GConv2 = nn.Conv2d(in_channels = outc,
- out_channels = outc,
- kernel_size = (1, 1),
- stride = (1, 1),
- padding = (0, 0),
- groups = 5,
- bias = False)
- self.bn2 = nn.BatchNorm2d(outc)
- self.ELU = nn.ELU(inplace = False)
- self.initialize()
-
- def initialize(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_uniform_(m.weight, gain = 1)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- def forward(self, x, x_p, L):
- x = self.bn2(self.GConv2(self.ELU(self.bn1(self.GConv1(x)))))
- y = torch.einsum('bijk,kp->bijp', (x, L))
- y = self.ELU(torch.add(y, x_p))
- return y

参考STS_HGCN的init部分:
- self.resGCN1 = resGCN(inc = dim*5,
- outc = dim*5)
- self.resGCN2 = resGCN(inc = dim*5,
- outc = dim*5)
- self.resGCN3 = resGCN(inc = dim*5,
- outc = dim*5)
- self.resGCN4 = resGCN(inc = dim*5,
- outc = dim*5)
可见resGCN的init部分传入的inc和outc都是64*5=320。resGCN由两个1*1卷积层组成,但group=5,这决定了输入卷积通道与输出卷积通道被分成5组,仅组内可见。如图6所示,分成5组是因为每组都对应于一个频段的intrarhythm graph。
resGCN的forward部分是严格按照图5进行的。
在forward的第一行计算,第二行计算
。由于
的尺寸为128*320*1*19,故
的尺寸也为128*320*1*19。注意,这两行实际上执行了图卷积的前向传播。其中H是节点特征,A是节点邻接矩阵。
forward的第三行执行残差块连接,将当前层的计算结果与之前所有层的输出进行累加,最后使用ELU激活函数,得到当前层的输出。
参考STS_HGCN的forward部分,可知G1,G2,G3,G4代表4层resGCN,而resGCN的第二个输入参数则是之前所有层的输出之和,与文章所述内容一致:
- x = self.ELU(self.bn1(sum(torch.einsum('pijk,bijk->bpijk', (imp, x)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G1(x)))
-
- G1 = self.resGCN1(x, x, L).contiguous()
- G1 = self.ELU(self.bn2(sum(torch.einsum('pijk,bijk->bpijk', (imp, G1)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G2(G1)))
-
- G2 = self.resGCN2(G1, torch.add(x, G1), L).contiguous()
- G2 = self.ELU(self.bn3(sum(torch.einsum('pijk,bijk->bpijk', (imp, G2)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G3(G2)))
-
- G3 = self.resGCN3(G2, torch.add(torch.add(x, G1), G2), L).contiguous()
- G3 = self.ELU(self.bn4(sum(torch.einsum('pijk,bijk->bpijk', (imp, G3)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G4(G3)))
-
- G4 = self.resGCN4(G3, torch.add(torch.add(torch.add(x, G1), G2), G3), L).contiguous()
- G4 = self.ELU(self.bn5(sum(torch.einsum('pijk,bijk->bpijk', (imp, G4)).split(self.dim, 2)).view(s1, s2, s3, s4)+self.rhythmAtt_G5(G4)))
按照图2所示,经过5层rhythmAtt-G和4层resGCN后,五个频段的计算结果被单独分开:
A, B, C, D, E = G4.split(self.dim, 1)
由于G4的尺寸仍为128*320*1*19,因此得到的A,B,C,D,E各为128*64*1*19大小。同时由于组卷积的原因,各个频段之间并未混淆。
随后是节点池化。A,B,C,D,E即为各个频段的节点特征,对其沿着EEG通道维池化,即期望输出大小应为128*64*1*1。下面看forward部分的具体操作:
- y = torch.cat((self.m1(A.view(A.size(0), A.size(1), -1)).unsqueeze(-1).contiguous(),
- self.m2(B.view(B.size(0), B.size(1), -1)).unsqueeze(-1).contiguous(),
- self.m3(C.view(C.size(0), C.size(1), -1)).unsqueeze(-1).contiguous(),
- self.m4(D.view(D.size(0), D.size(1), -1)).unsqueeze(-1).contiguous(),
- self.m5(E.view(E.size(0), E.size(1), -1)).unsqueeze(-1).contiguous()), 1)
A.view(A.size(0), A.size(1), -1)这一步是将A变成128*64*19的尺寸,使其可以与1*19的池化向量相乘,所以self.m1(A.view(A.size(0), A.size(1), -1))的尺寸是128*64*1,对其进行unsqueeze(-1)操作,尺寸变成128*64*1*1,达到期望输出。最后将五个池化层输出沿第1维串联,得到128*320*1*1的输出。
最后将池化输出送入分类器进行分类。分类器的定义为:
- class classification_net(nn.Module):
- def __init__(self, inc, tmp, outc):
- super(classification_net, self).__init__()
- self.conv1 = nn.Conv2d(in_channels = inc*5,
- out_channels = tmp,
- kernel_size = (1, 1),
- stride = 1,
- padding = (0, 0),
- bias = True)
- self.conv2 = nn.Conv2d(in_channels = tmp,
- out_channels = outc,
- kernel_size = (1, 1),
- stride = 1,
- padding = (0, 0),
- bias = True)
- self.initialize()
-
- def initialize(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_uniform_(m.weight, gain = 1)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- def forward(self, x):
- y = self.conv2(self.conv1(x))
- return y

按照图7所示,分类器由2个线性层构成,这里使用1*1卷积层代替线性层。第1个线性层的输入维度为64*5=320,输出维度为32,第2个线性层的输入维度为32,输出维度为2,实现预测。其实个人认为应该在第2个线性层后面加上softmax激活函数,就完美了,不知作者是否忽略了这一点。
由于输出y的尺寸为128*2*1*1,使用squeeze()值为1的维度进行压缩,最终得到128*2的预测标签:
pred = self.classification_net(y).squeeze()
至此,关于STS-HGCN的部分已经全部介绍完毕。关于论文中的另一个重要部分——主动学习(AL)则没有对应的代码,因此也无法进一步深挖了。但无论是ST-SENet,还是STS-HGCN,在设计思路上都有很多可借鉴之处,同时通过阅读代码,加深了对各个组件底层实现的理解。
论文常读常新,每次阅读都会有不一样的感觉,希望大家能融会贯通,举一反三,在此基础上设计出更优秀的模型。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。