当前位置: 首页 > article >正文

基于MNE的EEGNet 神经网络的脑电信号分类实战(附完整源码)

利用MNE中的EEG数据,进行EEGNet神经网络的脑电信号分类实现:

代码:

代码主要包括一下几个步骤:
1)从MNE中加载脑电信号,并进行相应的预处理操作,得到训练集、验证集以及测试集,每个集中都包括数据和标签;
2)基于tensorflow构建EEGNet网络模型;
3)编译模型,配置损失函数、优化器和评估指标等,并进行模型训练和预测;
4)绘制训练集和验证集的损失曲线以及训练集和验证集的准确度曲线。
代码如下:

import mne
import os
from pathlib import Path
import numpy as np
from keras.src.utils import np_utils

from mne import io
from mne.datasets import sample
import matplotlib.pyplot as plt
import pathlib

from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
from keras.src.callbacks import ModelCheckpoint

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression


def EEGNet(nb_classes, Chans=64, Samples=128,
           dropoutRate=0.5, kernelLength=64,
           F1=8, D=2, F2=16, norm_rate=0.25,
           dropout_type='Dropout'):
    """
    EEGNet模型的实现。

    参数:
    - nb_classes: int, 输出类别的数量。
    - Chans: int, 通道数,默认为64。
    - Samples: int, 每个通道的样本数,默认为128。
    - dropoutRate: float, Dropout率,默认为0.5。
    - kernelLength: int, 卷积核的长度,默认为64。
    - F1: int, 第一个卷积层的滤波器数量,默认为8。
    - D: int, 深度乘法器,默认为2。
    - F2: int, 第二个卷积层的滤波器数量,默认为16。
    - norm_rate: float, 权重范数约束,默认为0.25。
    - dropout_type: str, Dropout类型,默认为'Dropout'。

    返回:
    - Model: Keras模型对象。
    """

    # 根据dropout_type参数确定使用哪种Dropout方式
    if dropout_type == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropout_type == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropout_type must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # 定义模型的输入层
    input1 = Input(shape=(Chans, Samples, 1))

    # 第一个卷积块
    block1 = Conv2D(F1, (1, kernelLength), padding='same',
                       input_shape=(Chans, Samples, 1),
                       use_bias=False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((Chans, 1), use_bias=False,
                               depth_multiplier=D,
                               depthwise_constraint=max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropoutRate)(block1)

    # 第二个卷积块
    block2 = SeparableConv2D(F2, (1, 16),
                               use_bias=False, padding='same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropoutRate)(block2)

    # 将卷积块的输出展平以便输入到全连接层
    flatten = Flatten(name='flatten')(block2)

    # 定义全连接层
    dense = Dense(nb_classes, name='dense', kernel_constraint=max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name='softmax')(dense)

    # 创建并返回模型
    return Model(inputs=input1, outputs=softmax)


def get_data4EEGNet(kernels, chans, samples):
    """
    为EEGNet模型准备数据。

    该函数从指定的文件路径中读取原始EEG数据和事件数据,进行预处理,
    包括滤波、选择通道、分割数据集,并将数据集按给定的通道、核数和样本数进行重塑。

    参数:
    kernels - 数据集中的核数量。
    chans - 数据集中的通道数量。
    samples - 数据集中的样本数量。

    返回:
    X_train, X_validate, X_test, y_train, y_validate, y_test - 分别是训练、验证和测试数据集,
    以及相应的标签。
    """
    # 设置图像数据格式,确保数据维度顺序正确
    K.set_image_data_format('channels_last')

    # 定义数据路径
    data_path = Path("C:\\Users\\72671\\mne_data\\MNE-sample-data")

    # 定义原始数据和事件数据的文件路径
    raw_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif")
    event_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif")

    # 定义时间范围和事件ID
    tmin, tmax = -0., 1
    event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

    # 读取并预处理原始数据
    raw = io.Raw(raw_fname, preload=True, verbose=False)
    raw.filter(2, None, method='iir')
    events = mne.read_events(event_fname)

    # 设置无效通道并选择所需通道类型
    raw.info['bads'] = ['MEG 2443']
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                           exclude='bads')

    # 创建epochs数据集
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False, picks=picks, baseline=None,
                        preload=True, verbose=False)
    labels = epochs.events[:, -1]

    # 获取数据并进行缩放
    X = epochs.get_data(copy=False) * 1e6
    y = labels

    # 分割数据集为训练、验证和测试集
    X_train = X[0:144, ]
    y_train = y[0:144]
    X_validate = X[144:216, ]
    y_validate = y[144:216]
    X_test = X[216:, ]
    y_test = y[216:]

    # 将训练、验证和测试数据集中的标签转换为one-hot编码
    # 减1是因为标签通常从1开始计数,而one-hot编码需要从0开始
    y_train = np_utils.to_categorical(y_train-1)
    y_validate = np_utils.to_categorical(y_validate-1)
    y_test = np_utils.to_categorical(y_test-1)


    # 重塑数据集以匹配EEGNet模型的输入要求
    X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
    X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
    X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

    # 返回准备好的数据集
    return X_train, X_validate, X_test, y_train, y_validate, y_test


#########################################################################
# 定义模型参数
kernels, chans, samples = 1, 60, 151
# 获取预处理后的EEG数据集
X_train, X_validate, X_test, y_train, y_validate, y_test = get_data4EEGNet(kernels, chans, samples)

# 初始化EEGNet模型
model = EEGNet(nb_classes=4, Chans=chans, Samples=samples, dropoutRate=0.5,
               kernelLength=32, F1=8, D=2, F2=16, dropout_type='Dropout')

# 编译模型,配置损失函数、优化器和评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# 设置模型检查点以保存最佳模型
checkpointer = ModelCheckpoint(filepath='./models/EEGNet_best_model.h5', verbose=1, save_best_only=True)
# 定义类别权重
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}

# 训练模型
fittedModel = model.fit(X_train, y_train, batch_size=32, epochs=500, verbose=2,
                        validation_data=(X_validate, y_validate),
                        callbacks=[checkpointer], class_weight=class_weights)

# 加载最佳模型权重
model.load_weights('./models/EEGNet_best_model.h5')

# 对测试集进行预测
probs = model.predict(X_test)
# 获取预测标签
preds = probs.argmax(axis=-1)
# 计算分类准确率
acc = np.mean(preds == y_test.argmax(axis=-1))

# 输出分类准确率
print("Classification accuracy: %f " % (acc))


# 获取训练历史
history = fittedModel.history

# 绘制训练集和验证集的损失曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# 绘制训练集和验证集的准确度曲线
plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

效果如下:

在这里插入图片描述

参考资料:

论文链接: EEGNet: a compact convolutional neural network for EEG-based brain–computer interfaces(Journal of Neural Engineering,SCI JCR2,Impact Factor:4.141)
Github链接: the Army Research Laboratory (ARL) EEGModels project


http://www.kler.cn/a/441861.html

相关文章:

  • wireshark工具简介
  • 回归算法、聚类算法、决策树、随机森林、神经网络
  • vulnhub靶机(ReconForce)
  • KubeSphere 与 Pig 微服务平台的整合与优化:全流程容器化部署实践
  • 使用 Blazor 和 Elsa Workflows 作为引擎的工作流系统开发
  • 消息队列实战指南:三大MQ 与 Kafka 适用场景全解析
  • CAD xy坐标标注(跟随鼠标位置实时移动)——C#插件实现
  • dify智能体系列:selenium有啥好玩的?
  • 如何为IntelliJ IDEA配置JVM参数
  • springboot中Controller内文件上传到本地以及阿里云
  • 【Prompt Engineering】2.迭代优化
  • 【附源码】Electron Windows桌面壁纸开发中的 CommonJS 和 ES Module 引入问题以及 Webpack 如何处理这种兼容
  • 判题机的开发(代码沙箱、三种模式、工厂模式、策略模式优化、代理模式)
  • 单片机:实现蜂鸣器数码管的显示(附带源码)
  • Numpy基本介绍
  • Leetcode打卡:形成目标字符串需要的最少字符串数II
  • 如何在 Linux 服务器上部署 Pydio Cells 教程
  • STM32F407ZGT6-UCOSIII笔记7:优先级反转现象
  • 【图形渲染】【Unity Shader】【Nvidia CG】有用的参考资料链接
  • Composer指定php版本执行(windows)
  • git branch -r(--remotes )显示你本地仓库知道的所有 远程分支 的列表
  • Hadoop是什么?Hadoop介绍
  • workman服务端开发模式-应用开发-总架构逻辑说明
  • 虚拟现实辅助工程技术在航空领域的应用
  • git pull 和 git pull --rebase 区别
  • 初见react