
Paper Reading & STS-HGCN-AL Code Walkthrough (Part 1)


Paper: Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction


Paper link: Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction (IEEE Xplore)


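I. Introduction


  1. A novel STS-HGCN-AL scheme is proposed for automatic seizure prediction. It addresses the heterogeneity between patients by inferring a representative EEG and exploring its irregular spatio-temporal responses.
  2. Two graph-convolution variants are proposed: (a) residual graph convolution (ResGCN) and (b) rhythm attention units (RhythmAtt units).
  3. A semi-supervised active learning strategy is studied to adaptively infer the patient-specific optimal preictal interval.


Figure 1. Overall framework of the paper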


II. Code Walkthrough

Figure 2. ST-SENet

As Figure 2 shows, ST-SENet consists of four modules: Temporal embedding, Multi-level spectral analysis, Multi-scale temporal analysis, and Group convolution squeeze and excitation.

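Note the top-left corner of Figure 2: the input to ST-SENet is a set of independent components (ICs), not the raw EEG signal. According to the paper, the EEG signal on each channel is produced by several mutually interacting sources, and independent component analysis (ICA) maps the EEG onto mutually independent ICs, each arising from local field activity in a specific cortical region, which removes the interference between sources. For example, given an EEG recording of N channels × T sample points, ICA separates it into N independent source signals of T samples each, so the N×T shape is unchanged while the coupling between channels is removed. The paper obtains the ICs with fastICA.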

Figure 3. Temporal embedding

A. Temporal embedding

The temporal embedding module uses a stack of temporal convolution layers to embed the raw ICs in the time domain.



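if __name__ == "__main__":
    # Simulated input: (batch, channels, EEG electrodes, sample points).
    # The deprecated Variable() wrapper is not needed since PyTorch 0.4.
    x = torch.randn(128, 1, 19, 1280)
    model = ST_SENet(1, 19, 256)
    output = model(x)
    print(output.size())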

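x is simulated data with shape (batch, channels, EEG_electrodes, sample_points). ST_SENet is the class that implements the paper's ST-SENet. Since this part is about Temporal embedding, only the first few lines of ST_SENet matter for now: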

class ST_SENet(nn.Module):
    def __init__(self, inc, chan_num, si, outc = 64, num_of_layer = 1):
        super(ST_SENet, self).__init__()
        # fi = floor(log2(si)); it later controls the number of wavelet levels
        self.fi = math.floor(math.log2(si))
        self.embedding = Embedding_Block(Input_Layer,
                                         Residual_Block,
                                         num_of_layer = num_of_layer,
                                         inc = inc,
                                         outc = 4)

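Because the dataset used in the paper is sampled at 256 Hz, the 256 in ST_SENet(1, 19, 256) is presumably the sampling rate, which corresponds to the parameter si. chan_num should be the number of EEG channels. Each EEG sample is a 2-D array (electrodes × sample points), so it is treated as a single-channel image and inc is set to 1.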
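As a quick check of these arguments, the sketch below recomputes fi and the channel count that results from concatenating the embedding with the raw input, for ST_SENet(1, 19, 256). The variable embed_channels is a name introduced here for illustration; the expression itself is the one passed as inc to MultiLevel_Spectral.

```python
import math

# Constructor arguments of ST_SENet(1, 19, 256)
inc, chan_num, si = 1, 19, 256
num_of_layer = 1  # default value in the constructor signature

# fi = floor(log2(si)); a 256 Hz sampling rate gives 8
fi = math.floor(math.log2(si))

# Channels after concatenating the embedded ICs with the raw ICs
# (embed_channels is an illustrative name, not part of the original code)
embed_channels = 4 * int(math.pow(2, num_of_layer)) + inc

print(fi, embed_channels)
```

With the default num_of_layer = 1, this gives 8 embedded channels plus 1 raw channel.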



def Embedding_Block(input_block, Residual_Block, num_of_layer, inc, outc):
    layers = []
    layers.append(input_block(inc = inc))
    # Each residual block doubles the channel count: outc*2^i -> outc*2^(i+1)
    for i in range(0, num_of_layer):
        layers.append(Residual_Block(inc = int(math.pow(2, i)*outc),
                                     outc = int(math.pow(2, i+1)*outc)))
    return nn.Sequential(*layers)
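To make the channel progression in Embedding_Block concrete, here is a small sketch that lists the (inc, outc) pair of every Residual_Block the loop would create. The helper residual_channel_plan is hypothetical and only repeats the arithmetic of the loop above.

```python
import math

def residual_channel_plan(num_of_layer, outc):
    """Return the (inc, outc) pair of each Residual_Block built by the loop."""
    return [(int(math.pow(2, i) * outc), int(math.pow(2, i + 1) * outc))
            for i in range(num_of_layer)]

# Default configuration used by ST_SENet: one residual block, outc = 4
plan_default = residual_channel_plan(num_of_layer=1, outc=4)
# A deeper, purely illustrative configuration
plan_deeper = residual_channel_plan(num_of_layer=3, outc=4)

print(plan_default)
print(plan_deeper)
```

With the defaults, the single residual block maps 4 channels to 8, which is where the 8 embedded ICs come from.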



Printing the layers list for num_of_layer = 1 gives:

[Input_Layer(
  (conv_input): Conv2d(1, 4, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn_input): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Residual_Block(
  (conv_expand): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (conv1): Conv2d(4, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)]


class Input_Layer(nn.Module):
    def __init__(self, inc):
        super(Input_Layer, self).__init__()
        # Note: inc is accepted but unused; in_channels is hard-coded to 1
        self.conv_input = nn.Conv2d(in_channels = 1,
                                    out_channels = 4,
                                    kernel_size = (1, 3),
                                    stride = 1,
                                    padding = (0, 1),
                                    bias = False)
        self.bn_input = nn.BatchNorm2d(4)
        self.initialize()
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        output = self.bn_input(self.conv_input(x))
        return output

class Residual_Block(nn.Module):
    def __init__(self, inc, outc):
        super(Residual_Block, self).__init__()
        # Compare values with != ; the original "is not" tests object identity
        if inc != outc:
            # 1x1 conv so the shortcut matches the output channel count
            self.conv_expand = nn.Conv2d(in_channels = inc,
                                         out_channels = outc,
                                         kernel_size = 1,
                                         stride = 1,
                                         padding = 0,
                                         bias = False)
        else:
            self.conv_expand = None
        self.conv1 = nn.Conv2d(in_channels = inc,
                               out_channels = outc,
                               kernel_size = (1, 3),
                               stride = 1,
                               padding = (0, 1),
                               bias = False)
        self.bn1 = nn.BatchNorm2d(outc)
        self.conv2 = nn.Conv2d(in_channels = outc,
                               out_channels = outc,
                               kernel_size = (1, 3),
                               stride = 1,
                               padding = (0, 1),
                               bias = False)
        self.bn2 = nn.BatchNorm2d(outc)
        self.initialize()
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        if self.conv_expand is not None:
            identity_data = self.conv_expand(x)
        else:
            identity_data = x
        output = self.bn1(self.conv1(x))
        output = self.conv2(output)
        # The shortcut is added before the second BN; no activation is applied
        output = self.bn2(torch.add(output,identity_data))
        return output

Figure 4. Correspondence between the code and the figure







Compare this with the paper's own wording: "Because convolution operators essentially equate to a lowpass filter [23], the temporal embedding block, that is, successive temporal convolution and batch normalization (BN) operations, is first adopted to infer a patient-specific optimal filter-band for the subsequent analysis." So Temporal embedding is indeed built from temporal convolutions and BN layers.
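The convolutions in Input_Layer and Residual_Block use (1, 3) kernels with padding (0, 1), so they filter only along the time axis and leave both spatial sizes unchanged. A minimal sketch of that arithmetic, using the standard convolution output-size formula (conv_out_len is a helper name introduced here, and dilation is assumed to be 1):

```python
def conv_out_len(n, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation assumed to be 1)."""
    return (n + 2 * padding - kernel) // stride + 1

# Time axis: 1280 samples, kernel 3, padding 1 -> length preserved
time_len = conv_out_len(1280, kernel=3, stride=1, padding=1)
# Electrode axis: 19 electrodes, kernel 1, no padding -> electrodes never mixed
elec_len = conv_out_len(19, kernel=1, stride=1, padding=0)

print(time_len, elec_len)
```

Because the kernel height is 1, each electrode row is filtered independently, which is what makes these layers purely temporal.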


def forward(self, x):
    # Temporal embedding
    embedding_x = self.embedding(x)
    # Concatenate the raw ICs and the temporally embedded ICs along the channel axis
    cat_x = torch.cat((embedding_x, x), 1)

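Inspecting the shapes gives torch.Size([128, 9, 19, 1280]) for cat_x and torch.Size([128, 1, 19, 1280]) for x. The temporal embedding therefore produces 8 embedded ICs of size 19×1280 (the count is set by the number of kernels in the Temporal embedding convolutions), which are concatenated with the single raw input channel.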

B. Multi-level spectral analysis


Figure 5. Multi-level spectral analysis

Multi-level spectral analysis takes the ICs obtained in Part A as input. Its definition appears in the __init__ of ST_SENet:

self.MultiLevel_Spectral = MultiLevel_Spectral(inc = 4*int(math.pow(2, num_of_layer))+inc)



import numpy as np
from scipy import io  # assumed: "io" in the original refers to scipy.io

class MultiLevel_Spectral(nn.Module):
    def __init__(self, inc, params_path='./scaling_filter.mat'):
        super(MultiLevel_Spectral, self).__init__()
        self.filter_length = io.loadmat(params_path)['Lo_D'].shape[1]
        # groups = inc: every input channel gets its own low-pass/high-pass pair
        self.conv = nn.Conv2d(in_channels = inc,
                              out_channels = inc*2,
                              kernel_size = (1, self.filter_length),
                              stride = (1, 2), padding = 0,
                              groups = inc,
                              bias = False)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                f = io.loadmat(params_path)
                Lo_D, Hi_D = np.flip(f['Lo_D'], axis = 1).astype('float32'), np.flip(f['Hi_D'], axis = 1).astype('float32')
                m.weight.data = torch.from_numpy(np.concatenate((Lo_D, Hi_D), axis = 0)).unsqueeze(1).unsqueeze(1).repeat(inc, 1, 1, 1)
                # The wavelet filters are fixed and not trained
                m.weight.requires_grad = False
    def self_padding(self, x):
        # Circular padding along the time axis (dim 3). The original passed
        # filter_length//2-1 as the dim, which equals 3 only when filter_length is 8 or 9.
        pad = self.filter_length//2-1
        return torch.cat((x[:, :, :, -pad:], x, x[:, :, :, 0:pad]), 3)
    def forward(self, x):
        out = self.conv(self.self_padding(x))
        # Even output channels are low-pass, odd output channels are high-pass
        return out[:, 0::2,:, :], out[:, 1::2, :, :]





Figure 6. Shape change before and after the wavelet convolution
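The shape change can be reproduced by hand. The sketch below follows the padding and stride used in MultiLevel_Spectral; the filter length of 8 is an assumption (the contents of scaling_filter.mat are not shown in this post), and wavelet_level_len is a helper name introduced for illustration.

```python
def wavelet_level_len(n_samples, filter_length):
    """Time-axis length after one MultiLevel_Spectral call."""
    pad = filter_length // 2 - 1             # samples wrapped onto each side by self_padding
    padded = n_samples + 2 * pad
    return (padded - filter_length) // 2 + 1  # Conv2d with stride (1, 2) and padding 0

# Assumed filter length of 8 taps
level1_len = wavelet_level_len(1280, filter_length=8)
level2_len = wavelet_level_len(level1_len, filter_length=8)

# Channel bookkeeping: 9 input channels -> 18 conv outputs -> two tensors of 9 channels
inc = 9
channels_per_output = (inc * 2) // 2

print(level1_len, level2_len, channels_per_output)
```

Under that assumption each call halves the time axis exactly, while the channel count of each returned tensor stays at 9.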


Back in the forward method of ST_SENet, Multi-level spectral analysis does indeed take the ICs from Part A (cat_x) as input. The decomposition proceeds as follows:

# Multi-level spectral analysis
for i in range(1, self.fi-2):
    if i <= self.fi-7:
        # Early levels: keep only the low-pass half
        if i == 1:
            out, _ = self.MultiLevel_Spectral(cat_x)
        else:
            out, _ = self.MultiLevel_Spectral(out)
    elif i == self.fi-6:
        if self.fi >= 8:
            out, gamma = self.MultiLevel_Spectral(out)
        else:
            out, gamma = self.MultiLevel_Spectral(cat_x)
    elif i == self.fi-5:
        out, beta = self.MultiLevel_Spectral(out)
    elif i == self.fi-4:
        out, alpha = self.MultiLevel_Spectral(out)
    elif i == self.fi-3:
        # Last level: both halves are kept
        delta, theta = self.MultiLevel_Spectral(out)

Figure 7. Shapes of the data in each decomposed frequency band
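To see which rhythm each branch of the loop produces, the following sketch traces the same loop symbolically for si = 256 and 1280 samples. It assumes ideal half-band splits of the range from 0 Hz to si/2 and an exact halving of the time axis at every level; band_plan is a hypothetical helper, not part of the original code.

```python
import math

def band_plan(si, n_samples):
    """Trace the decomposition loop, assuming ideal half-band splits of [0, si/2] Hz."""
    fi = math.floor(math.log2(si))
    named = {fi - 6: "gamma", fi - 5: "beta", fi - 4: "alpha"}
    lo, hi, length = 0.0, si / 2, n_samples
    bands = {}
    for i in range(1, fi - 2):
        mid = (lo + hi) / 2
        length //= 2                             # stride-2 filtering halves the time axis
        if i in named:
            bands[named[i]] = (mid, hi, length)  # high-pass half becomes a rhythm
        elif i == fi - 3:
            bands["theta"] = (mid, hi, length)   # last level keeps both halves
            bands["delta"] = (lo, mid, length)
        hi = mid                                 # continue with the low-pass half
    return bands

plan = band_plan(si=256, n_samples=1280)
for name, (f_lo, f_hi, length) in plan.items():
    print(f"{name}: {f_lo:g}-{f_hi:g} Hz, {length} samples")
```

Under these assumptions the first level (64 to 128 Hz) is discarded, and the remaining levels line up with the usual gamma, beta, alpha, theta and delta ranges.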


C. Multi-scale temporal analysis

Figure 8. Multi-scale temporal analysis

Multi-scale temporal analysis also takes the ICs from Part A as input. Its definition appears in the __init__ of ST_SENet:

self.MultiScale_Temporal_gamma = MultiScale_Temporal(pow(2, self.fi-3)//8, 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_beta = MultiScale_Temporal(pow(2, self.fi-3)//4, 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_alpha = MultiScale_Temporal(pow(2, self.fi-3)//2, 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_theta = MultiScale_Temporal(pow(2, self.fi-3), 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_delta = MultiScale_Temporal(pow(2, self.fi-3), 4*int(math.pow(2, num_of_layer))+inc)

The implementation of Multi-scale temporal analysis:

class MultiScale_Temporal(nn.Module):
    def __init__(self, kernel_size, inc):
        super(MultiScale_Temporal, self).__init__()
        # stride == kernel_size: non-overlapping windows along the time axis
        self.conv = nn.Conv2d(in_channels = inc,
                              out_channels = inc,
                              kernel_size = (1, kernel_size),
                              stride = (1, kernel_size),
                              padding = (0, 0),
                              bias = False)
        self.bn = nn.BatchNorm2d(inc)
        self.elu = nn.ELU(inplace = True)
        self.initialize()
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        output = self.elu(self.bn(self.conv(x)))
        return output

MultiScale_Temporal consists of one convolution layer, one BN layer and one ELU activation, which matches the paper: "Thus, multiscale temporal analysis, that is, temporal convolution with trainable kernel parameters, BN and, exponential linear unit (ELU) operations, captures temporal embeddings of the dynamic ICs at different scales in a data-driven way."
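The kernel sizes passed to the five MultiScale_Temporal instances, and the time-axis lengths they produce, can be worked out directly. This sketch assumes si = 256 and 1280 input samples, as in the test input; the dictionary names are introduced here for illustration.

```python
import math

si, n_samples = 256, 1280
fi = math.floor(math.log2(si))
base = pow(2, fi - 3)  # 2^(fi-3), the largest kernel size

# Kernel size per rhythm, copied from the constructor calls in __init__
kernel_sizes = {
    "gamma": base // 8,
    "beta": base // 4,
    "alpha": base // 2,
    "theta": base,
    "delta": base,
}

# stride == kernel size and padding == 0, so windows do not overlap
out_lens = {band: (n_samples - k) // k + 1 for band, k in kernel_sizes.items()}

print(kernel_sizes)
print(out_lens)
```

The kernel sizes appear to be chosen so that each temporal branch ends up with the same time-axis length as the matching wavelet band, which is what allows the two to be concatenated along the channel axis.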








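Figure 9. Verifying the output sizes of multi-scale temporal analysis

Back in the forward method of ST_SENet, a cat operation concatenates the multi-level spectral analysis results (gamma, beta, alpha, theta, delta) with the multi-scale temporal analysis results along dimension 1. The resulting (x1, x2, x3, x4, x5) should therefore have shapes 128×18×19×320, 128×18×19×160, 128×18×19×80, 128×18×19×40 and 128×18×19×40, respectively.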

x1 = torch.cat((self.MultiScale_Temporal_gamma(cat_x), gamma), 1)
x2 = torch.cat((self.MultiScale_Temporal_beta(cat_x), beta), 1)
x3 = torch.cat((self.MultiScale_Temporal_alpha(cat_x), alpha), 1)
x4 = torch.cat((self.MultiScale_Temporal_theta(cat_x), theta), 1)
x5 = torch.cat((self.MultiScale_Temporal_delta(cat_x), delta), 1)

Figure 10. Verifying the concatenation results

D. Group convolution squeeze and excitation

Figure 11. gcSE module


self.gamma_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 7)
self.beta_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 7)
self.alpha_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 3)
self.theta_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 3)
self.delta_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 3)
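With the defaults, each SENet receives inc = 18 and outc = 32, and its convolutions use groups = 2. Given the torch.cat order in Part C (temporal features first, spectral features second), this means the two halves are filtered separately. The sketch below illustrates that property on a small random tensor; the tensor sizes and variable names are chosen here for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same configuration as conv0 in the gamma branch: 18 -> 32 channels, 2 groups
conv = nn.Conv2d(18, 32, kernel_size=(1, 7), padding=(0, 3), groups=2, bias=False)
weight_shape = tuple(conv.weight.shape)  # each kernel sees only 18 / 2 = 9 input channels

x = torch.randn(2, 18, 19, 40)
x_changed = x.clone()
x_changed[:, 9:] = 0.0  # modify only the second group's input channels

with torch.no_grad():
    y = conv(x)
    y_changed = conv(x_changed)

# Outputs 0..15 depend only on inputs 0..8, outputs 16..31 only on inputs 9..17
first_half_same = torch.allclose(y[:, :16], y_changed[:, :16])
second_half_same = torch.allclose(y[:, 16:], y_changed[:, 16:])

print(weight_shape, first_half_same, second_half_same)
```

The two halves only interact later, through the channel weights computed by the SE layers.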









class SENet(nn.Module):
    def __init__(self, inc, outc, kernel_size, reduction = 8):
        super(SENet, self).__init__()
        self.conv0 = nn.Conv2d(in_channels = inc,
                               out_channels = outc,
                               kernel_size = (1, kernel_size),
                               stride = (1, 1),
                               padding = (0, kernel_size//2),
                               groups = 2,
                               bias = False)
        self.bn0 = nn.BatchNorm2d(outc)
        self.se0 = SELayer(outc, reduction)
        self.conv1 = nn.Conv2d(in_channels = outc,
                               out_channels = outc,
                               kernel_size = (1, kernel_size),
                               stride = (1, 1),
                               padding = (0, kernel_size//2),
                               groups = 2,
                               bias = False)
        self.bn1 = nn.BatchNorm2d(outc)
        self.se1 = SELayer(outc, reduction)
        self.conv2 = nn.Conv2d(in_channels = outc,
                               out_channels = outc,
                               kernel_size = (1, kernel_size),
                               stride = (1, 1),
                               padding = (0, kernel_size//2),
                               groups = 2,
                               bias = False)
        self.bn2 = nn.BatchNorm2d(outc)
        self.se2= SELayer(outc, reduction)
        self.elu = nn.ELU(inplace = False)
        self.initialize()
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        # Three stages of grouped conv -> BN -> SE -> ELU
        out = self.elu(self.se0(self.bn0(self.conv0(x))))
        out = self.elu(self.se1(self.bn1(self.conv1(out))))
        out = self.elu(self.se2(self.bn2(self.conv2(out))))
        return out


class SELayer(nn.Module):
    '''
    Original SE block, details refer to "Jie Hu et al.: Squeeze-and-Excitation Networks"
    '''
    def __init__(self, channel, reduction = 16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias = False),
                                nn.ELU(inplace = True),
                                nn.Linear(channel // reduction, channel, bias = False),
                                nn.Sigmoid())
    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: one scalar per channel
        y = self.avg_pool(x).view(b, c)
        # Excitation: per-channel weights in (0, 1)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

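The underlying principle is not covered in detail here. Back in the forward method of ST_SENet, each of (x1, x2, x3, x4, x5) from Part C, i.e. the concatenation of the multi-level spectral and multi-scale temporal results, passes through three gcSE operations (the three conv, BN, SE, ELU stages inside SENet), matching Figure 11.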
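For readers who want the gist of squeeze-and-excitation without the full background, here is a minimal step-by-step sketch of what SELayer.forward computes, written out on a small random tensor. The sizes (32 channels, reduction 8) follow the defaults above; the variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

channels, reduction = 32, 8
x = torch.randn(2, channels, 19, 40)

# Squeeze: average every channel over electrodes and time -> (batch, channels)
squeezed = x.mean(dim=(2, 3))

# Excitation: bottleneck MLP ending in a sigmoid -> one weight per channel
fc = nn.Sequential(
    nn.Linear(channels, channels // reduction, bias=False),
    nn.ELU(),
    nn.Linear(channels // reduction, channels, bias=False),
    nn.Sigmoid(),
)
with torch.no_grad():
    scale = fc(squeezed)

# Rescale: broadcast the per-channel weights over the 19 x 40 feature map
y = x * scale.view(2, channels, 1, 1)

print(tuple(squeezed.shape), tuple(scale.shape), tuple(y.shape))
```

The output has the same shape as the input; only the relative magnitude of the channels changes.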

x1 = self.gamma_x(x1)
x2 = self.beta_x(x2)
x3 = self.alpha_x(x3)
x4 = self.theta_x(x4)
x5 = self.delta_x(x5)

Figure 12. SENet output shapes


The last step in Figure 11 is Temporal pooling, which is simple to implement.


self.average_pooling = nn.AdaptiveAvgPool2d((chan_num, 1))


x1 = self.average_pooling(x1)
x2 = self.average_pooling(x2)
x3 = self.average_pooling(x3)
x4 = self.average_pooling(x4)
x5 = self.average_pooling(x5)

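Adaptive average pooling reduces each input to (128, 32, 19, 1). Note that the argument of nn.AdaptiveAvgPool2d is the output size, not the kernel size, so with (chan_num, 1) the pooling averages over the time dimension, as shown in Figure 13:

Figure 13. Output shape after adaptive average pooling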
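A short sketch of this point: with an output size of (19, 1) and 19 electrode rows in the input, nn.AdaptiveAvgPool2d behaves like a plain mean over the time axis. The tensor below is random and smaller than the real batch, purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

chan_num = 19
x = torch.randn(2, 32, chan_num, 40)  # (batch, features, electrodes, time)

# The argument is the desired OUTPUT size: 19 rows, 1 column
pool = nn.AdaptiveAvgPool2d((chan_num, 1))
pooled = pool(x)

# Equivalent manual computation: average over the time axis only
manual = x.mean(dim=3, keepdim=True)
same = torch.allclose(pooled, manual, atol=1e-5)

print(tuple(pooled.shape), same)
```

Because the output size is fixed, the same layer works for all five bands even though their time-axis lengths differ (320, 160, 80, 40, 40).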

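Comparing the code with Figure 11, (x1, x2, x3, x4, x5) still has to pass through one more convolution layer with 64 kernels of size 1×1, corresponding to the TConv+BN+ELU block to the left of Temporal pooling in Figure 11. The code is: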


# Five identical 1x1 conv + BN + ELU blocks, one per rhythm
self.reshapeA = nn.Sequential(nn.Conv2d(in_channels = outc//2,
                                        out_channels = outc,
                                        kernel_size = (1, 1),
                                        stride = (1, 1),
                                        padding = (0, 0),
                                        groups = 1,
                                        bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace=False))
self.reshapeB = nn.Sequential(nn.Conv2d(in_channels = outc//2,
                                        out_channels = outc,
                                        kernel_size = (1, 1),
                                        stride = (1, 1),
                                        padding = (0, 0),
                                        groups = 1,
                                        bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))
self.reshapeD = nn.Sequential(nn.Conv2d(in_channels = outc//2,
                                        out_channels = outc,
                                        kernel_size = (1, 1),
                                        stride = (1, 1),
                                        padding = (0, 0),
                                        groups = 1,
                                        bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))
self.reshapeT = nn.Sequential(nn.Conv2d(in_channels = outc//2,
                                        out_channels = outc,
                                        kernel_size = (1, 1),
                                        stride = (1, 1),
                                        padding = (0, 0),
                                        groups = 1,
                                        bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))
self.reshapeG = nn.Sequential(nn.Conv2d(in_channels = outc//2,
                                        out_channels = outc,
                                        kernel_size = (1, 1),
                                        stride = (1, 1),
                                        padding = (0, 0),
                                        groups = 1,
                                        bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))


x1 = self.reshapeG(x1)
x2 = self.reshapeB(x2)
x3 = self.reshapeA(x3)
x4 = self.reshapeT(x4)
x5 = self.reshapeD(x5)



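At this point (x1, x2, x3, x4, x5) are the U in Figure 11, i.e. the ICs of the five frequency bands after ST-SENet. Each has size 19×64, matching the (E×64) in the paper, where E is the number of channels.

One more point: the forward method of ST_SENet finally concatenates (x1, x2, x3, x4, x5) and returns the result, so the returned tensor has shape torch.Size([128, 320, 1, 19]):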

return torch.cat((x1, x2, x3, x4, x5), 1).permute(0, 1, 3, 2).contiguous()

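The contiguous() call here is not what separates the return value from (x1, x2, x3, x4, x5): torch.cat already writes its result into a newly allocated tensor, so later changes to x1 through x5 cannot affect it. permute() only changes how the existing memory is indexed, and contiguous() returns a tensor whose memory layout matches the permuted dimension order, copying the data only if the tensor is not already contiguous. This matters for later operations such as view() that require contiguous memory.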
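Both behaviours can be checked on small tensors. The sketch below uses toy shapes chosen for illustration; it first shows what permute() and contiguous() do to the memory layout, then shows that the output of torch.cat is independent of its inputs even without contiguous().

```python
import torch

# Part 1: permute() returns a view; contiguous() copies it into a new buffer
a = torch.arange(6).reshape(2, 3)
b = a.permute(1, 0)
c = b.contiguous()

b_is_contiguous = b.is_contiguous()
c_is_contiguous = c.is_contiguous()
b_shares_memory = b.data_ptr() == a.data_ptr()
c_is_copy = c.data_ptr() != a.data_ptr()

# Part 2: torch.cat allocates new memory, with or without contiguous()
x1 = torch.zeros(2, 64, 19, 1)
x2 = torch.ones(2, 64, 19, 1)
out = torch.cat((x1, x2), 1).permute(0, 1, 3, 2)
out_shape = tuple(out.shape)

x1.fill_(5.0)  # modify an input after the concatenation
out_unaffected = bool((out[:, :64] == 0).all())

print(b_is_contiguous, c_is_contiguous, b_shares_memory, c_is_copy)
print(out_shape, out_unaffected)
```

With all five bands the channel dimension would be 5 × 64 = 320, which gives the [128, 320, 1, 19] shape quoted above.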


