部分标签数据集生成与过滤特定标签方法
完整代码总结
这段代码的目的是通过构建一个部分标签学习(Partial Label Learning, PLL)框架来生成一个包含部分标签的数据集,并且支持根据给定的标签列表对数据集进行筛选和过滤。代码包含了多个类和函数,主要分为以下几部分:
- 数据预处理与加载:使用 PyTorch 和 torchvision 来加载 CIFAR-10 数据集,并对其进行标准化处理。
- 部分标签数据集的生成:为每个样本生成多个候选标签,并模拟部分标签学习中的标签不确定性。
- 数据集筛选:根据用户提供的标签列表来过滤掉包含特定标签的样本,生成一个新的数据集。
- DataLoader 设置:通过 DataLoader 对数据集进行批量加载,并在训练时进行处理。
各方法与类的解释
1. PartialLabelDataset 类
该类用于生成一个部分标签数据集,每个样本会被赋予一个候选标签集,其中可能包含真实标签以及一些随机标签。
__init__(self, dataset, candidate_size)
:初始化数据集,将输入的原始数据集与候选标签集大小保存为类的属性。candidate_size
表示每个样本的候选标签数量。generate_partial_labels(self)
:为每个样本生成部分标签。每个样本会从真实标签开始,然后添加若干个随机的标签,直到候选标签集的大小为candidate_size
。生成的标签会被打乱顺序,以模拟标签不确定性。__getitem__(self, index)
:获取索引index
对应样本的图像数据、部分标签和真实标签。真实标签是从数据集中直接获取的,部分标签是根据generate_partial_labels()
方法生成的。__len__(self)
:返回数据集中样本的数量。
2. FilteredPartialLabelDataset 类
该类用于过滤掉原始部分标签数据集中的特定标签样本,并根据过滤后的数据生成新的数据集。
__init__(self, dataset, partial_labels, filtered_indices)
:初始化该类时,需要输入原始数据集、完整的部分标签列表以及要保留的样本索引列表(即不包含过滤标签的样本)。__getitem__(self, index)
:根据过滤后的索引,从原始数据集中获取图像和标签数据。__len__(self)
:返回筛选后的样本数量。
3. filter_partial_label_dataset 函数
这个函数用于对原始部分标签数据集进行标签筛选,去掉包含特定标签的样本,并返回过滤后的数据集和 DataLoader。
dataset
:原始数据集(如 CIFAR-10)。partial_labels
:包含完整部分标签的列表,函数会基于此生成新的部分标签数据集。candidate_size
:每个样本的候选标签集大小。filtered_labels
:一个标签列表,表示需要从部分标签中排除的标签。batch_size
:DataLoader 的批次大小。shuffle
:是否在 DataLoader 中打乱数据。num_workers
:DataLoader 的工作线程数。
函数首先根据 filtered_labels
过滤掉部分标签中包含这些标签的样本,接着根据过滤后的样本索引创建一个新的 FilteredPartialLabelDataset
。最终返回该新的数据集和对应的 DataLoader。
4. main 函数
该函数是代码的入口,负责生成部分标签数据集并创建 DataLoader。
- 通过
PartialLabelDataset
类生成一个包含部分标签的数据集(候选标签集大小为3)。 - 创建一个 DataLoader,用于批量加载部分标签数据集。
- 打印出部分标签数据集的一个批次样本的形状和标签信息。
在 main()
函数中,partial_label_dataset
被用来生成部分标签数据集,并且通过 filter_partial_label_dataset
函数对数据集进行标签过滤,排除包含标签 [5, 6, 7, 8, 9]
的样本。
代码流程图
-
数据加载与预处理:
- 使用
torchvision.datasets.CIFAR10
下载并加载 CIFAR-10 数据集。 - 对图像进行标准化处理(均值和标准差为0.5)。
- 使用
-
生成部分标签数据集:
- 在
PartialLabelDataset
中为每个样本生成多个候选标签(候选标签数为3),这些标签包括真实标签及随机标签。 - 使用
generate_partial_labels()
方法生成候选标签,并打乱顺序。
- 在
-
数据筛选:
- 使用
filter_partial_label_dataset
函数,根据用户提供的标签列表(如[5, 6, 7, 8, 9]
)过滤掉部分标签中包含这些标签的样本,创建新的数据集。
- 使用
-
数据加载器:
- 通过
DataLoader
创建数据加载器,使得在训练过程中可以批量读取数据。
- 通过
-
输出样本信息:
- 在
main()
函数中打印出部分标签的一个批次示例,包括图像的形状、部分标签和真实标签。
- 在
优点和可扩展性
- 部分标签学习:这段代码模拟了部分标签学习的场景,其中每个样本都有多个候选标签,这为部分标签学习任务提供了一个基础框架。
- 灵活的标签过滤:通过
filter_partial_label_dataset
函数,用户可以方便地过滤掉特定标签的样本。 - 可扩展性:可以将这个框架扩展到其他数据集(如 CIFAR-100、ImageNet 等),并灵活调整候选标签大小和过滤标签。
总结
这段代码提供了一个部分标签学习框架,可以用来处理具有部分标签的不完整数据集,并提供了一种方法来筛选数据集中的特定标签。通过生成候选标签和对数据进行过滤,代码实现了部分标签学习任务的数据预处理与加载,为相关研究和应用提供了有效支持。
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 下载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# 定义合并后的部分标签数据集类
class PartialLabelDataset(Dataset):
def __init__(self, dataset, candidate_size):
"""
初始化部分标签数据集
:param dataset: 原始数据集对象(如 CIFAR-10)
:param candidate_size: 候选标签集的大小
:param filtered_labels: 不得存在于部分标签中的标签列表(可选)
"""
self.dataset = dataset
self.candidate_size = candidate_size
self.num_classes = len(dataset.classes)
self.targets = dataset.targets
self.partial_labels = self.generate_partial_labels()
def generate_partial_labels(self):
"""
为每个图像生成部分标签
:param filtered_labels: 不得存在于部分标签中的标签列表(可选)
:return: 部分标签列表
"""
partial_labels = []
for target in self.targets:
candidates = [target]
while len(candidates) < self.candidate_size:
random_label = np.random.randint(0, self.num_classes)
if random_label not in candidates :
candidates.append(random_label)
#打乱候选标签
np.random.shuffle(candidates)
partial_labels.append(candidates)
return partial_labels
def __getitem__(self, index):
image, _ = self.dataset[index]
partial_label = torch.tensor(self.partial_labels[index], dtype=torch.long)
true_label = torch.tensor(self.targets[index], dtype=torch.long) # 真实标签
return image, partial_label, true_label
def __len__(self):
return len(self.dataset)
class FilteredPartialLabelDataset(Dataset):
def __init__(self, dataset, partial_labels, filtered_indices):
"""
初始化筛选后的部分标签数据集
:param dataset: 原始数据集对象
:param partial_labels: 完整部分标签列表
:param filtered_indices: 筛选后的样本索引列表
"""
self.dataset = dataset
self.partial_labels = [partial_labels[i] for i in filtered_indices]
self.indices = filtered_indices
def __getitem__(self, index):
original_index = self.indices[index] #
image, _ = self.dataset[original_index]
partial_label = torch.tensor(self.partial_labels[index], dtype=torch.long)
true_label = torch.tensor(self.dataset.targets[original_index], dtype=torch.long) # 真实标签
return image, partial_label, true_label #表示这个类实例化之后,返回的就是这个样本的图像和部分标签
def __len__(self):
return len(self.indices)
def filter_partial_label_dataset(dataset, partial_labels, candidate_size=3, filtered_labels=None, batch_size=64, shuffle=True, num_workers=2):
"""
过滤数据集以排除部分标签中含有任何 filtered_labels 的样本。
:param dataset: 原始数据集(例如 CIFAR-10)
:param candidate_size: 候选标签集的大小(默认:3)
:param filtered_labels: 不得存在于部分标签中的标签列表
:param batch_size: DataLoader 的批次大小(默认:4)
:param shuffle: 是否在 DataLoader 中打乱数据(默认:True)
:param num_workers: DataLoader 的工作线程数(默认:2)
:return: (过滤后的数据集, DataLoader) 元组
"""
if filtered_labels is None:
raise ValueError("Filtered labels must be specified.")
# 将部分标签转换为 NumPy 数组以进行高效过滤
partial_labels_np = np.array(partial_labels)
# 创建样本中不包含任何 filtered_labels 的掩码
filtered_labels_mask = np.any(np.isin(partial_labels_np, filtered_labels), axis=1)
final_mask = ~filtered_labels_mask # 这个索引列中,只有不含要过滤的标签的样本才为 True
# 获取过滤后的索引
filtered_indices = np.where(final_mask)[0] # 过滤后的样本的索引,每个值对是该样本在原始数据集中的索引,可以据此得到该样本的真实标签
# 创建过滤后的部分标签数据集
new_partial_label_dataset = FilteredPartialLabelDataset(dataset, partial_labels, filtered_indices)
# 创建 DataLoader
new_partial_label_loader = DataLoader(new_partial_label_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
# 打印过滤后样本的信息
print("过滤后的样本数量:", len(filtered_indices))
# 可选:打印一个批次的示例
for images, partial_labels_batch , true_labels_batch in new_partial_label_loader:
print("新图像的形状:", images.shape)
print("新部分标签:", partial_labels_batch)
print("新真实标签:", true_labels_batch)
break
return new_partial_label_dataset, new_partial_label_loader
# 主函数:生成部分标签数据集并过滤
def main():
# 生成部分标签数据集,不包含标签5、6、7、8、9
partial_label_dataset = PartialLabelDataset(trainset, candidate_size=3)
# 创建 DataLoader
trainloader = DataLoader(partial_label_dataset, batch_size=4, shuffle=True, num_workers=2)
# 打印部分标签示例
for images, partial_labels, true_labels in trainloader:
print("图像的形状:", images.shape)
print("部分标签:", partial_labels)
print("真实标签:", true_labels)
break
if __name__ == '__main__':
main()
partial_label_dataset = PartialLabelDataset(trainset, candidate_size=3)
partial_labels = partial_label_dataset.generate_partial_labels()
filter_partial_label_dataset(trainset, partial_labels, candidate_size=3, filtered_labels=[5, 6, 7, 8, 9], batch_size=4, shuffle=True, num_workers=2)