Pytorch:Dataset的加载
文章目录
- 1. 准备数据
- 2. 创建自定义 `Dataset` 类
- 3. 实例化数据集对象
- 4. 使用 `DataLoader` 加载数据
- 5. 迭代数据集
- 6. 预处理和数据增强(可选)
- 7. 多线程加载(可选)
自定义数据集的加载在
PyTorch
中主要为以下几个步骤:
1. 准备数据
- 从文件中读取数据
- 对数据进行预处理
- 给数据打上标签label
- 合并数据(根据实际而定)
- 划分训练集和验证集(可使用
sklearn.model_selection
模块的train_test_split
函数)
2. 创建自定义 Dataset
类
创建一个继承自 torch.utils.data.Dataset
的自定义类,主要作用:
- 封装数据:
ECGDataset
类封装了数据和标签,使得它们可以作为一个整体被处理。 - 提供数据访问接口:通过实现
__getitem__
方法,ECGDataset
类提供了一个标准化的方式来访问数据集中的单个样本。 - 与
DataLoader
协同工作:Dataset
类与 PyTorch 的DataLoader
类紧密集成,DataLoader
可以利用Dataset
类提供的方法来实现批量加载、打乱数据、多线程加载等功能。
这个类需要实现两个方法:__len__
和 __getitem__
。
__len__
方法返回数据集中样本的数量。__getitem__
方法根据索引返回数据集中的一个样本。
如果有一个心电(ECG)数据集,自定义 Dataset
类可以如下:
from torch.utils.data import Dataset
class ECGDataset(Dataset):
def __init__(self, ecg_data, labels):
self.ecg_data = ecg_data
self.labels = labels
# 计算样本的数量
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
ecg_sample = self.ecg_data[idx]
label = self.labels[idx]
#最后返回读取到的数据,记住返回一定要是tensor的形式
return ecg_sample, label
或者是以这种方式(加入transform参数):
from torch.utils.data import Dataset
class ECGDataset(Dataset):
def __init__(self, ecg_data, labels,transform=None):
self.ecg_data = ecg_data
self.labels = labels
self.transform = transform
# 计算样本的数量
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
ecg_sample = self.ecg_data[idx]
label = self.labels[idx]
if self.transform:
ecg_sample = self.transform(ecg_sample)
#最后返回读取到的数据,记住返回一定要是tensor的形式
return ecg_sample, label
ps
: transforms.Compose
通常用于图像数据的预处理,如调整大小、裁剪、翻转和归一化等操作。然而,心电信号(ECG)是一维时间序列数据,不是二维图像数据,因此不能直接应用上述为图像设计的 transforms
。
对于心电信号,我们通常会采用不同的预处理方法,一些适用于心电信号的常见预处理 transforms
:
为心电信号定义一个简单的预处理流程:
import numpy as np
from scipy.signal import butter, filtfilt
# 定义心电信号预处理的 transforms
class ECGTransform:
def __init__(self, sample_rate, lowcut, highcut, filter_order, segment_length):
self.sample_rate = sample_rate
self.lowcut = lowcut
self.highcut = highcut
self.filter_order = filter_order
self.segment_length = segment_length
def bandpass_filter(self, ecg_signal):
# 定义带通滤波器参数
nyq = 0.5 * self.sample_rate
low = self.lowcut / nyq
high = self.highcut / nyq
b, a = butter(self.filter_order, [low, high], btype='band')
# 应用滤波器
filtered_signal = filtfilt(b, a, ecg_signal)
return filtered_signal
def standardize(self, ecg_signal):
# 标准化信号
return (ecg_signal - np.mean(ecg_signal)) / np.std(ecg_signal)
def segment_signal(self, ecg_signal):
# 将信号分割成固定长度的片段
segments = []
for start in range(0, len(ecg_signal) - self.segment_length, self.segment_length):
segment = ecg_signal[start:start + self.segment_length]
segments.append(segment)
return np.array(segments)
# 使用预处理 transforms
transform = ECGTransform(sample_rate=250, lowcut=0.5, highcut=15.0, filter_order=5, segment_length=5000)
ecg_signal = ... # 加载心电信号数据
filtered_ecg = transform.bandpass_filter(ecg_signal)
standardized_ecg = transform.standardize(filtered_ecg)
segmented_ecg = transform.segment_signal(standardized_ecg)
在这个示例中,我们创建了一个 ECGTransform
类,它包含带通滤波、标准化和信号分割的方法。
3. 实例化数据集对象
使用你的数据(特征和标签)来创建 Dataset
类的实例,来创建数据集对象。
# 假设 x 是特征数据,y 是标签数据
ecg_dataset = ECGDataset(x, y)
4. 使用 DataLoader
加载数据
使用 torch.utils.data.DataLoader
来包装你的数据集对象,创建数据加载器。DataLoader
可以提供额外的功能,如自动打乱数据、批量加载、多线程加载等。
from torch.utils.data import DataLoader
# 创建 DataLoader 实例
data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True)
在这个例子中,batch_size=32
表示每次迭代返回 32 个样本的批次,shuffle=True
表示在每个 epoch 开始时打乱数据。
5. 迭代数据集
在你的训练或验证循环中,你可以迭代 DataLoader
实例来获取数据。
for epoch in range(num_epochs):
for batch_idx, (ecg_samples, labels) in enumerate(data_loader):
pass
在这个循环中,ecg_samples
和 labels
是从数据集中加载的批次数据和标签。
6. 预处理和数据增强(可选)
在自定义 Dataset
类中,你可以添加任何特定的预处理或数据增强步骤。这些步骤将在 __getitem__
方法中执行,确保每个样本在返回之前都经过了适当的处理。
7. 多线程加载(可选)
DataLoader
还支持多线程加载数据,可以通过设置 num_workers
参数来实现。
data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True, num_workers=4)
num_workers=4
表示使用 4 个进程来加载数据,这可以显著提高数据加载的效率。