当前位置: 首页 > article >正文

Pytorch:Dataset的加载

文章目录

      • 1. 准备数据
      • 2. 创建自定义 `Dataset` 类
      • 3. 实例化数据集对象
      • 4. 使用 `DataLoader` 加载数据
      • 5. 迭代数据集
      • 6. 预处理和数据增强(可选)
      • 7. 多线程加载(可选)

自定义数据集的加载在 PyTorch 中主要为以下几个步骤:

1. 准备数据

  1. 从文件中读取数据
  2. 对数据进行预处理
  3. 给数据打上标签label
  4. 合并数据(根据实际而定)
  5. 划分训练集和验证集(可使用sklearn.model_selection模块的train_test_split函数)

2. 创建自定义 Dataset

创建一个继承自 torch.utils.data.Dataset 的自定义类,主要作用:

  1. 封装数据ECGDataset 类封装了数据和标签,使得它们可以作为一个整体被处理。
  2. 提供数据访问接口:通过实现 __getitem__ 方法,ECGDataset 类提供了一个标准化的方式来访问数据集中的单个样本。
  3. DataLoader 协同工作Dataset 类与 PyTorch 的 DataLoader 类紧密集成,DataLoader 可以利用 Dataset 类提供的方法来实现批量加载、打乱数据、多线程加载等功能。

这个类需要实现两个方法:__len____getitem__

  • __len__ 方法返回数据集中样本的数量。
  • __getitem__ 方法根据索引返回数据集中的一个样本。

如果有一个心电(ECG)数据集,自定义 Dataset 类可以如下:

from torch.utils.data import Dataset

class ECGDataset(Dataset):
    def __init__(self, ecg_data, labels):
        self.ecg_data = ecg_data
        self.labels = labels

    # 计算样本的数量
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        ecg_sample = self.ecg_data[idx]
        label = self.labels[idx]
        #最后返回读取到的数据,记住返回一定要是tensor的形式
        return ecg_sample, label

或者是以这种方式(加入transform参数):

from torch.utils.data import Dataset

class ECGDataset(Dataset):
    def __init__(self, ecg_data, labels,transform=None):
        self.ecg_data = ecg_data
        self.labels = labels
        self.transform = transform

    # 计算样本的数量
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        ecg_sample = self.ecg_data[idx]
        label = self.labels[idx]
        if self.transform:
            ecg_sample = self.transform(ecg_sample)
        #最后返回读取到的数据,记住返回一定要是tensor的形式
        return ecg_sample, label

pstransforms.Compose 通常用于图像数据的预处理,如调整大小、裁剪、翻转和归一化等操作。然而,心电信号(ECG)是一维时间序列数据,不是二维图像数据,因此不能直接应用上述为图像设计的 transforms

对于心电信号,我们通常会采用不同的预处理方法,一些适用于心电信号的常见预处理 transforms

为心电信号定义一个简单的预处理流程:

import numpy as np
from scipy.signal import butter, filtfilt

# 定义心电信号预处理的 transforms
class ECGTransform:
    def __init__(self, sample_rate, lowcut, highcut, filter_order, segment_length):
        self.sample_rate = sample_rate
        self.lowcut = lowcut
        self.highcut = highcut
        self.filter_order = filter_order
        self.segment_length = segment_length

    def bandpass_filter(self, ecg_signal):
        # 定义带通滤波器参数
        nyq = 0.5 * self.sample_rate
        low = self.lowcut / nyq
        high = self.highcut / nyq
        b, a = butter(self.filter_order, [low, high], btype='band')
        # 应用滤波器
        filtered_signal = filtfilt(b, a, ecg_signal)
        return filtered_signal

    def standardize(self, ecg_signal):
        # 标准化信号
        return (ecg_signal - np.mean(ecg_signal)) / np.std(ecg_signal)

    def segment_signal(self, ecg_signal):
        # 将信号分割成固定长度的片段
        segments = []
        for start in range(0, len(ecg_signal) - self.segment_length, self.segment_length):
            segment = ecg_signal[start:start + self.segment_length]
            segments.append(segment)
        return np.array(segments)

# 使用预处理 transforms
transform = ECGTransform(sample_rate=250, lowcut=0.5, highcut=15.0, filter_order=5, segment_length=5000)
ecg_signal = ...  # 加载心电信号数据
filtered_ecg = transform.bandpass_filter(ecg_signal)
standardized_ecg = transform.standardize(filtered_ecg)
segmented_ecg = transform.segment_signal(standardized_ecg)

在这个示例中,我们创建了一个 ECGTransform 类,它包含带通滤波、标准化和信号分割的方法。

3. 实例化数据集对象

使用你的数据(特征和标签)来创建 Dataset 类的实例,来创建数据集对象。

# 假设 x 是特征数据,y 是标签数据
ecg_dataset = ECGDataset(x, y)

4. 使用 DataLoader 加载数据

使用 torch.utils.data.DataLoader 来包装你的数据集对象,创建数据加载器。DataLoader 可以提供额外的功能,如自动打乱数据、批量加载、多线程加载等。

from torch.utils.data import DataLoader

# 创建 DataLoader 实例
data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True)

在这个例子中,batch_size=32 表示每次迭代返回 32 个样本的批次,shuffle=True 表示在每个 epoch 开始时打乱数据。

5. 迭代数据集

在你的训练或验证循环中,你可以迭代 DataLoader 实例来获取数据。

for epoch in range(num_epochs):
    for batch_idx, (ecg_samples, labels) in enumerate(data_loader):
        pass

在这个循环中,ecg_sampleslabels 是从数据集中加载的批次数据和标签。

6. 预处理和数据增强(可选)

在自定义 Dataset 类中,你可以添加任何特定的预处理或数据增强步骤。这些步骤将在 __getitem__ 方法中执行,确保每个样本在返回之前都经过了适当的处理。

7. 多线程加载(可选)

DataLoader 还支持多线程加载数据,可以通过设置 num_workers 参数来实现。

data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True, num_workers=4)

num_workers=4 表示使用 4 个进程来加载数据,这可以显著提高数据加载的效率。


http://www.kler.cn/a/588734.html

相关文章:

  • 百度贴吧IP和ID是什么意思?怎么查看
  • NPU、边缘计算与算力都是什么啊?
  • [leetcode] 面试经典 150 题——篇3:滑动窗口
  • 一分钟了解深度学习
  • Lisp语言的网络管理
  • 利用Java爬虫根据关键词获取商品列表:实战指南
  • 一份C#的笔试题及答案
  • 【NLP】 4. NLP项目流程与上下文窗口大小参数的影响
  • Kafka可视化工具KafkaTool工具的使用
  • Lua语言的嵌入式调试
  • qt 自带虚拟键盘的编译使用记录
  • 深入解析 React Diff 算法:原理、优化与实践
  • C或C++中实现数据结构课程中的链表、数组、树和图
  • matlab 模糊pid实现温度控制
  • nginx请求限流设置:常见的有基于 IP 地址的限流、基于请求速率的限流以及基于连接数的限流
  • Windows系统中安装Rust工具链方法
  • 数据结构篇——树(1)
  • 人工智能中神经网络是如何进行学习的
  • 1.Windows+vscode+cline+MCP配置
  • 传感云揭秘:边缘计算的革新力量