当前位置:   article > 正文

TensorFlow搭建卷积神经网络EEGNet处理脑电数据过程代码_eegnet代码

eegnet代码

TensorFlow搭建卷积神经网络EEGNet处理脑电数据过程代码
脑电信号采集设备由NT9200-32D型脑电图仪和NeuSen W系列无线脑电采集系统组成。采集后的信号用Matlab打开,保存在结构体数据中,原始信号的形式为16x640000 double。首先对数据进行手动分段,得到形状为[280,16,1000]的数组,其中280指trials,16指channels,1000指samples。
整个代码可分为**数据切分、搭建网络、训练模型、测试模型**四个部分。
1.导入包

import numpy as np
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
import scipy.io
from matplotlib import pyplot as plt
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

2.数据切分

K.set_image_data_format('channels_last')
samplesfile = scipy.io.loadmat('F:/holiday_code/attention/TSA/data/foursecond.mat')
X = samplesfile['eeg']  # EEG trials, stored in the MAT file under the variable "eeg"
event_id = dict(l=1, m=2, lm=3, ml=4)  # four-class motor-imagery label codes (1-based)
labels = samplesfile['Mark']  # label matrix, stored in the MAT file under "Mark"
y = labels[:, -1]  # class codes are taken from the last column of "Mark"
kernels, chans, samples = 1, 16, 1000
n_classes = len(event_id)

# Take 50/25/25 percent of the trials to train/validate/test.
# For the 280-trial recording this is exactly the 140/70/70 split; deriving the
# boundaries from X.shape[0] keeps the stated percentages for other recordings.
# NOTE(review): the split is contiguous (not shuffled); if trials are stored
# grouped by class, some classes may be absent from validate/test — confirm.
n_trials = X.shape[0]
n_train = n_trials // 2
n_val_end = n_trials * 3 // 4
X_train = X[:n_train]
Y_train = y[:n_train]
X_validate = X[n_train:n_val_end]
Y_validate = y[n_train:n_val_end]
X_test = X[n_val_end:]
Y_test = y[n_val_end:]

# One-hot encode the labels. Codes are 1-based, so shift to 0-based first.
# num_classes pins the width so every split yields (n, n_classes) targets even
# when a split happens not to contain the highest class; without it the
# validation targets could mismatch the model's softmax output width.
Y_train = np_utils.to_categorical(Y_train - 1, num_classes=n_classes)
Y_validate = np_utils.to_categorical(Y_validate - 1, num_classes=n_classes)
Y_test = np_utils.to_categorical(Y_test - 1, num_classes=n_classes)

# Reshape to the network's input layout (trials, channels, samples, kernels),
# matching K.set_image_data_format('channels_last') above.
# NOTE(review): assumes each trial in X is laid out as (channels, samples) in
# C order; reshape does not verify the axis order — confirm against the MAT file.
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

4.搭建网络

#导入需要的库
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
def EEGNet(nb_classes, Chans=16, Samples=1000,
           dropoutRate=0.5, kernLength=64, F1=8,
           D=2, F2=16, norm_rate=0.25, dropoutType='Dropout'):
    """Build an EEGNet-style Keras model for classifying EEG trials.

    Args:
        nb_classes: Number of output classes (units in the final dense layer).
        Chans: Number of EEG channels (height of the input tensor).
        Samples: Number of time samples per trial (width of the input tensor).
        dropoutRate: Dropout rate applied at the end of each of the two blocks.
        kernLength: Length of the first (temporal) convolution kernel.
        F1: Number of filters in the first temporal convolution.
        D: Depth multiplier of the depthwise convolution that spans all channels.
        F2: Number of filters in the separable convolution.
        norm_rate: Max-norm constraint applied to the dense layer's kernel.
        dropoutType: Either 'Dropout' or 'SpatialDropout2D'.

    Returns:
        An uncompiled ``Model`` mapping inputs of shape ``(Chans, Samples, 1)``
        to softmax probabilities over ``nb_classes`` classes.

    Raises:
        ValueError: If ``dropoutType`` is not one of the two supported names.
    """
    if dropoutType == 'SpatialDropout2D':
        dropout_layer = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropout_layer = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # The Input layer fixes the shape, so no input_shape is passed to Conv2D
    # (it is ignored in the functional API and warns in newer Keras versions).
    input1 = Input(shape=(Chans, Samples, 1))

    # Block 1: temporal convolution along the sample axis, then a depthwise
    # convolution whose (Chans, 1) kernel covers every channel at once.
    block1 = Conv2D(F1, (1, kernLength), padding='same',
                    use_bias=False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((Chans, 1), use_bias=False,
                             depth_multiplier=D,
                             depthwise_constraint=max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropout_layer(dropoutRate)(block1)

    # Block 2: separable convolution over time, then further temporal pooling.
    block2 = SeparableConv2D(F2, (1, 16),
                             use_bias=False, padding='same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropout_layer(dropoutRate)(block2)

    # Classifier: flatten, max-norm constrained dense layer, softmax.
    flatten = Flatten(name='flatten')(block2)
    dense = Dense(nb_classes, name='dense',
                  kernel_constraint=max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name='softmax')(dense)

    return Model(inputs=input1, outputs=softmax)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

5.训练模型

# Build the four-class EEGNet (note kernLength=32 here rather than the
# function's default of 64) and compile it for one-hot encoded targets.
model = EEGNet(
    nb_classes=4,
    Chans=16,
    Samples=1000,
    dropoutRate=0.5,
    kernLength=32,
    F1=8,
    D=2,
    F2=16,
    dropoutType='Dropout',
)
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Total parameter count of the model, kept for inspection.
numParams = model.count_params()

# Save a checkpoint only when the monitored quantity (ModelCheckpoint's default
# is val_loss) improves; set a valid path for your system.
# NOTE(review): with the default save_weights_only=False, Keras 3.x rejects a
# plain ".h5" path here — confirm the installed Keras version.
checkpointer = ModelCheckpoint(
    filepath='F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5',
    verbose=1,
    save_best_only=True,
)
# Equal weight for each of the four classes (i.e. no re-weighting).
class_weights = {cls: 1 for cls in range(4)}
fittedModel = model.fit(
    X_train,
    Y_train,
    batch_size=16,
    epochs=300,
    verbose=2,
    validation_data=(X_validate, Y_validate),
    callbacks=[checkpointer],
    class_weight=class_weights,
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

6.测试模型

# Restore the best checkpoint written during training, then score the test split.
model.load_weights('F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5')
probs = model.predict(X_test)
preds = probs.argmax(axis=-1)  # predicted class index per test trial
# Fraction of test trials whose predicted class matches the one-hot label.
acc = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % acc)

# Plot training/validation accuracy and loss per epoch on one set of axes,
# in the same order as the legend entries below.
for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(fittedModel.history[key])
plt.title('acc & loss')
plt.xlabel('epoch')
plt.legend(['acc', 'val_acc', 'loss', 'val_loss'], loc='upper right')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

7.分类结果
在这里插入图片描述
整个网络框架大概就是这样,这是其中一个被试的分类结果,属于分类效果比较好的,其他被试可能由于数据质量,网络结构等原因分类效果不是很理想,考虑数据增强以及网络结构优化去提高分类准确率。

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/787264
推荐阅读
相关标签
  

闽ICP备14008679号