当前位置: 首页 > article >正文

【代码模板】统计数据集的均值和标准差

背景

在数据预处理时,通常会对样本进行标准化操作,使样本的均值为0,标准差为1,从而提高训练的稳定性。

进行标准化操作时,需要预先统计数据集的均值和标准差。下面的demo展示了如何实现这个操作。

demo

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_set = datasets.CIFAR10(
    root="dataset/", transform=transforms.ToTensor(), download=True
)
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)


def get_mean_std(loader):
    # var[X] = E[X**2] - E[X]**2
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0

    for data, _ in tqdm(loader):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += torch.mean(data**2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean**2) ** 0.5

    return mean, std


mean, std = get_mean_std(train_loader)
print(mean)
print(std)

参考资料

Pytorch Quick Tip: Calculate Mean and Standard Deviation of Data


http://www.kler.cn/news/334566.html

相关文章:

  • C++面试速通宝典——16
  • Spring Boot大学生就业招聘系统的设计与优化
  • 9.29 LeetCode 3304、3300、3301
  • Kubernetes云原生存储解决方案之 Rook Ceph实践探究
  • 【可答疑】基于51单片机的智能台灯(含仿真、代码、报告、演示视频等)
  • 用Spring Boot搭建IT交流与学习平台
  • 机器学习系列篇章0 --- 人工智能机器学习相关概念梳理
  • 【复习】CSS中的选择器
  • 定时器TIM输出比较及其应用
  • 幂等性接口实现
  • 前端公共资源CDN存储库大全
  • Pikachu-unsafe upfileupload-getimagesize
  • 【深度学习】— softmax回归、网络架构、softmax 运算、小批量样本的向量化、交叉熵
  • 【C++ STL】手撕vector,深入理解vector的底层
  • 【分布式微服务云原生】掌握分布式缓存:Redis与Memcached的深入解析与实战指南
  • 【RabbitMq源码阅读】分析RabbitMq发送消息源码
  • stm32定时器中断和外部中断
  • 深入探讨指令调优的局限性
  • 删除GitHub仓库的fork依赖 (Delete fork dependency of a GitHub repository)
  • 简单介绍Wiki和历史