当前位置: 首页 > article >正文

部分标签数据集生成与过滤特定标签方法

完整代码总结

这段代码的目的是通过构建一个部分标签学习(Partial Label Learning, PLL)框架来生成一个包含部分标签的数据集,并且支持根据给定的标签列表对数据集进行筛选和过滤。代码包含了多个类和函数,主要分为以下几部分:

  1. 数据预处理与加载:使用 PyTorch 和 torchvision 来加载 CIFAR-10 数据集,并对其进行标准化处理。
  2. 部分标签数据集的生成:为每个样本生成多个候选标签,并模拟部分标签学习中的标签不确定性。
  3. 数据集筛选:根据用户提供的标签列表来过滤掉包含特定标签的样本,生成一个新的数据集。
  4. DataLoader 设置:通过 DataLoader 对数据集进行批量加载,并在训练时进行处理。

各方法与类的解释

1. PartialLabelDataset 类

该类用于生成一个部分标签数据集,每个样本会被赋予一个候选标签集,其中可能包含真实标签以及一些随机标签。

  • __init__(self, dataset, candidate_size):初始化数据集,将输入的原始数据集与候选标签集大小保存为类的属性。candidate_size 表示每个样本的候选标签数量。
  • generate_partial_labels(self):为每个样本生成部分标签。每个样本会从真实标签开始,然后添加若干个随机的标签,直到候选标签集的大小为 candidate_size。生成的标签会被打乱顺序,以模拟标签不确定性。
  • __getitem__(self, index):获取索引 index 对应样本的图像数据、部分标签和真实标签。真实标签是从数据集中直接获取的,部分标签是根据 generate_partial_labels() 方法生成的。
  • __len__(self):返回数据集中样本的数量。
2. FilteredPartialLabelDataset 类

该类用于过滤掉原始部分标签数据集中的特定标签样本,并根据过滤后的数据生成新的数据集。

  • __init__(self, dataset, partial_labels, filtered_indices):初始化该类时,需要输入原始数据集、完整的部分标签列表以及要保留的样本索引列表(即不包含过滤标签的样本)。
  • __getitem__(self, index):根据过滤后的索引,从原始数据集中获取图像和标签数据。
  • __len__(self):返回筛选后的样本数量。
3. filter_partial_label_dataset 函数

这个函数用于对原始部分标签数据集进行标签筛选,去掉包含特定标签的样本,并返回过滤后的数据集和 DataLoader。

  • dataset:原始数据集(如 CIFAR-10)。
  • partial_labels:包含完整部分标签的列表,函数会基于此生成新的部分标签数据集。
  • candidate_size:每个样本的候选标签集大小。
  • filtered_labels:一个标签列表,表示需要从部分标签中排除的标签。
  • batch_size:DataLoader 的批次大小。
  • shuffle:是否在 DataLoader 中打乱数据。
  • num_workers:DataLoader 的工作线程数。

函数首先根据 filtered_labels 过滤掉部分标签中包含这些标签的样本,接着根据过滤后的样本索引创建一个新的 FilteredPartialLabelDataset。最终返回该新的数据集和对应的 DataLoader。

4. main 函数

该函数是代码的入口,负责生成部分标签数据集并创建 DataLoader。

  • 通过 PartialLabelDataset 类生成一个包含部分标签的数据集(候选标签集大小为3)。
  • 创建一个 DataLoader,用于批量加载部分标签数据集。
  • 打印出部分标签数据集的一个批次样本的形状和标签信息。

main() 函数中,partial_label_dataset 被用来生成部分标签数据集,并且通过 filter_partial_label_dataset 函数对数据集进行标签过滤,排除包含标签 [5, 6, 7, 8, 9] 的样本。

代码流程图

  1. 数据加载与预处理

    • 使用 torchvision.datasets.CIFAR10 下载并加载 CIFAR-10 数据集。
    • 对图像进行标准化处理(均值和标准差为0.5)。
  2. 生成部分标签数据集

    • PartialLabelDataset 中为每个样本生成多个候选标签(候选标签数为3),这些标签包括真实标签及随机标签。
    • 使用 generate_partial_labels() 方法生成候选标签,并打乱顺序。
  3. 数据筛选

    • 使用 filter_partial_label_dataset 函数,根据用户提供的标签列表(如 [5, 6, 7, 8, 9])过滤掉部分标签中包含这些标签的样本,创建新的数据集。
  4. 数据加载器

    • 通过 DataLoader 创建数据加载器,使得在训练过程中可以批量读取数据。
  5. 输出样本信息

    • main() 函数中打印出部分标签的一个批次示例,包括图像的形状、部分标签和真实标签。

优点和可扩展性

  1. 部分标签学习:这段代码模拟了部分标签学习的场景,其中每个样本都有多个候选标签,这为部分标签学习任务提供了一个基础框架。
  2. 灵活的标签过滤:通过 filter_partial_label_dataset 函数,用户可以方便地过滤掉特定标签的样本。
  3. 可扩展性:可以将这个框架扩展到其他数据集(如 CIFAR-100、ImageNet 等),并灵活调整候选标签大小和过滤标签。

总结

这段代码提供了一个部分标签学习框架,可以用来处理具有部分标签的不完整数据集,并提供了一种方法来筛选数据集中的特定标签。通过生成候选标签和对数据进行过滤,代码实现了部分标签学习任务的数据预处理与加载,为相关研究和应用提供了有效支持。

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 下载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 定义合并后的部分标签数据集类
class PartialLabelDataset(Dataset):
    def __init__(self, dataset, candidate_size):
        """
        初始化部分标签数据集
        :param dataset: 原始数据集对象(如 CIFAR-10)
        :param candidate_size: 候选标签集的大小
        :param filtered_labels: 不得存在于部分标签中的标签列表(可选)
        """
        self.dataset = dataset
        self.candidate_size = candidate_size
        self.num_classes = len(dataset.classes)
        self.targets = dataset.targets
        self.partial_labels = self.generate_partial_labels()
    def generate_partial_labels(self):
        """
        为每个图像生成部分标签
        :param filtered_labels: 不得存在于部分标签中的标签列表(可选)
        :return: 部分标签列表
        """
        partial_labels = []
        for target in self.targets:
            candidates = [target]
            while len(candidates) < self.candidate_size:
                random_label = np.random.randint(0, self.num_classes)
                if random_label not in candidates :
                    candidates.append(random_label)
            #打乱候选标签
            np.random.shuffle(candidates)
            partial_labels.append(candidates)
        return partial_labels
    def __getitem__(self, index):
        image, _ = self.dataset[index]
        partial_label = torch.tensor(self.partial_labels[index], dtype=torch.long)
        true_label = torch.tensor(self.targets[index], dtype=torch.long)  # 真实标签
        return image, partial_label, true_label

    def __len__(self):
        return len(self.dataset)
class FilteredPartialLabelDataset(Dataset):
    def __init__(self, dataset, partial_labels, filtered_indices):
        """
        初始化筛选后的部分标签数据集
        :param dataset: 原始数据集对象
        :param partial_labels: 完整部分标签列表
        :param filtered_indices: 筛选后的样本索引列表
        """
        self.dataset = dataset
        self.partial_labels = [partial_labels[i] for i in filtered_indices]
        self.indices = filtered_indices

    def __getitem__(self, index):
        original_index = self.indices[index] # 
        image, _ = self.dataset[original_index]
        partial_label = torch.tensor(self.partial_labels[index], dtype=torch.long)
        true_label = torch.tensor(self.dataset.targets[original_index], dtype=torch.long)  # 真实标签
        return image, partial_label, true_label  #表示这个类实例化之后,返回的就是这个样本的图像和部分标签
        

    def __len__(self):
        return len(self.indices)
def filter_partial_label_dataset(dataset, partial_labels, candidate_size=3, filtered_labels=None, batch_size=64, shuffle=True, num_workers=2):
    """
    过滤数据集以排除部分标签中含有任何 filtered_labels 的样本。

    :param dataset: 原始数据集(例如 CIFAR-10)
    :param candidate_size: 候选标签集的大小(默认:3)
    :param filtered_labels: 不得存在于部分标签中的标签列表
    :param batch_size: DataLoader 的批次大小(默认:4)
    :param shuffle: 是否在 DataLoader 中打乱数据(默认:True)
    :param num_workers: DataLoader 的工作线程数(默认:2)
    :return: (过滤后的数据集, DataLoader) 元组
    """
    if filtered_labels is None:
        raise ValueError("Filtered labels must be specified.")

    # 将部分标签转换为 NumPy 数组以进行高效过滤
    partial_labels_np = np.array(partial_labels)

    # 创建样本中不包含任何 filtered_labels 的掩码
    filtered_labels_mask = np.any(np.isin(partial_labels_np, filtered_labels), axis=1)
    final_mask = ~filtered_labels_mask  # 这个索引列中,只有不含要过滤的标签的样本才为 True

    # 获取过滤后的索引
    filtered_indices = np.where(final_mask)[0]  # 过滤后的样本的索引,每个值对是该样本在原始数据集中的索引,可以据此得到该样本的真实标签

    # 创建过滤后的部分标签数据集
    new_partial_label_dataset = FilteredPartialLabelDataset(dataset, partial_labels, filtered_indices)

    # 创建 DataLoader
    new_partial_label_loader = DataLoader(new_partial_label_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    # 打印过滤后样本的信息
    print("过滤后的样本数量:", len(filtered_indices))

    # 可选:打印一个批次的示例
    for images, partial_labels_batch , true_labels_batch in new_partial_label_loader:
        print("新图像的形状:", images.shape)
        print("新部分标签:", partial_labels_batch)
        print("新真实标签:", true_labels_batch)
        break

    return new_partial_label_dataset, new_partial_label_loader

# 主函数:生成部分标签数据集并过滤
def main():
    # 生成部分标签数据集,不包含标签5、6、7、8、9
    partial_label_dataset = PartialLabelDataset(trainset, candidate_size=3)

    # 创建 DataLoader
    trainloader = DataLoader(partial_label_dataset, batch_size=4, shuffle=True, num_workers=2)

    # 打印部分标签示例
    for images, partial_labels, true_labels in trainloader:
        print("图像的形状:", images.shape)
        print("部分标签:", partial_labels)
        print("真实标签:", true_labels)
        break

if __name__ == '__main__':
    main()
    partial_label_dataset = PartialLabelDataset(trainset, candidate_size=3)
    partial_labels = partial_label_dataset.generate_partial_labels()
    filter_partial_label_dataset(trainset, partial_labels, candidate_size=3, filtered_labels=[5, 6, 7, 8, 9], batch_size=4, shuffle=True, num_workers=2)

http://www.kler.cn/a/593932.html

相关文章:

  • AcWing 838:堆排序 ← 数组模拟
  • 双碳战略下的电能质量革命:解码电力系统的健康密码
  • oracle 索引
  • 世界职业院校技能大赛(软件测试)技术创新思路分享(二)
  • VSCode C/C++ 开发环境完整配置及常见问题
  • Android Launcher3终极改造:全屏应用展示实战!深度解析去除Hotseat的隐藏技巧
  • 数据结构之栈(C语言)
  • 轨道交通DSP+FPGA主控板(6U)板卡,支持逻辑控制、数据处理、通信管理、系统安全保护切换等功能
  • NET6 WebApi第5讲:中间件(源码理解,俄罗斯套娃怎么来的?);Web 服务器 (Nginx / IIS / Kestrel)、WSL、SSL/TSL
  • 【01-驱动学习】
  • 华为流程体系建设与运营(123页PPT)(文末有下载方式)
  • 【Spring 默认是否管理 Request 和 Session Bean 的生命周期?】
  • Android Coil3 Fetcher preload批量Bitmap拼接扁平宽图,Kotlin
  • 头歌 JAVA 桥接模式实验
  • GitHub Actions上关于“Cannot Find Matching Keyid”或“Corepack/PNPM Not Found”的错误
  • 英伟达消费级RTX显卡配置表
  • Linux 驱动开发笔记--1.驱动开发的引入
  • 基于51单片机的多路数据采集系统proteus仿真
  • 效用系统简介
  • bootstrap介绍(前端框架)(提供超过40种可复用组件,从导航栏到轮播图,从卡片到弹窗)bootstrap框架