当前位置: 首页 > article >正文

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)的计算过程

        cifar10数据集的众多demo中,在数据加载环节,transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)这条指令是经常看到的。这是一个 PyTorch 中用于图像数据标准化的函数调用,它将图像的每个通道的值进行标准化处理,使得数据的均值变为 (0.4914, 0.4822, 0.4465),标准差变为 (0.2023, 0.1994, 0.2010)。
        关于均值、均方差以及标准化函数transforms.Normalize()的文章太多了,这里记录一下计算过程。
        对于 CIFAR-10 数据集,均值和标准差的计算方法如下:
        1、收集数据集: 首先,你需要加载整个 CIFAR-10 数据集。CIFAR-10 数据集包含 60,000 张 32x32 的彩色图像,分为 10 个类别。
        2、计算每个通道的均值:
        
对于每个图像,将 RGB 三个通道的值提取出来。然后对所有图像的每个通道的像素值求和,然后除以总像素数(图像数量乘以每个图像的像素数)。
        3、计算每个通道的标准差:
        
对于每个图像,计算每个通道的像素值与该通道均值的差的平方。再对所有图像的每个通道的平方差求和,然后除以总像素数,最后取平方根。

import torch
from torchvision import datasets, transforms

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor()
])

# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)

# 将数据集转换为Tensor
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)

# 初始化均值和标准差
mean = torch.zeros(3)
std = torch.zeros(3)

# 计算均值和标准差
for images, _ in train_loader:
    for i in range(3):  # 遍历RGB三个通道
        mean[i] += images[:, i, :, :].mean()   # 计算每个通道的均值
        std[i] += images[:, i, :, :].std()     # 计算每个通道的标准差

# 对三个通道的均值和标准差求平均
mean /= 3
std /= 3

# 计算平均值
mean /= len(train_loader)
std /= len(train_loader)

print(f'均值: {mean}')   # 均值: tensor([0.4914, 0.4822, 0.4465])
print(f'标准差: {std}')  # 标准差: tensor([0.2023, 0.1994, 0.2010])

上述代码稍加改造,就可用于自定义数据集的计算:

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os


# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir   # 图片文件夹的路径
        self.transform = transform   # 数据预处理
        self.img_files = os.listdir(img_dir)  # 图片文件列表

    def __len__(self):   # 获取数据集大小
        return len(self.img_files)

    def __getitem__(self, idx):  # 获取图片数据
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor()
])

# 创建自定义数据集实例
custom_dataset = CustomDataset(img_dir='自定义数据集的文件夹路径', transform=transform)

# 创建数据加载器
custom_loader = DataLoader(custom_dataset, batch_size=1, shuffle=False)

# 初始化均值和标准差
mean = torch.zeros(3)
std = torch.zeros(3)

# 计算均值和标准差
for images in custom_loader:
    for i in range(3):  # 遍历RGB三个通道
        mean[i] += images[:, i, :, :].mean()  # 计算每个通道的均值
        std[i] += images[:, i, :, :].std()  # 计算每个填充的标准差

# 计算平均值
mean /= len(custom_loader)
std /= len(custom_loader)

print(f'均值: {mean}')
print(f'标准差: {std}')


http://www.kler.cn/news/361956.html

相关文章:

  • 亿佰特STM32MP13工业核心板【学习】
  • 深入解析 Jenkins 自动化任务链:三大方法实现任务间依赖与状态控制
  • R语言笔记(一)
  • 【小沐学Golang】基于Go语言搭建静态文件服务器
  • word表格跨页后自动生成的顶部横线【去除方法】
  • 圆周率的估算
  • qt5.12.12插件机制无法加载插件问题
  • 毕业生找工作的攻略:从校园到职场的成功之路
  • R语言绘图——文本注释
  • LLaMA Factory环境配置
  • 猎板高频PCB的制成能力分享
  • CISAW安全集成,协助组织构建坚固的信息防护堡垒
  • 【一站式学会Kotlin】第二十五 Kotlin内部类和嵌套类的区别和案例
  • 智慧交通新征程:亿维锐创与图为科技达成战略合作
  • STM32+CubeMX -- 开发辅助工具
  • 蓝桥杯基本操作和运算
  • 【某农业大学计算机网络实验报告】实验一 集线器和交换机的对比
  • excel将文本型数字转变为数值型数字
  • ppt模板一键套用怎么操作?制作ppt基础步骤手把手教你
  • Java中的异步编程模型
  • LN9361 低噪声电荷泵 DC/DC 转换器
  • 什么是缓存?
  • 群控系统服务端开发模式-业务流程图补充
  • Android使用协程实现自定义Toast
  • 一、python基础
  • IDE使用技巧与插件推荐