当前位置: 首页 > article >正文

PyTorch——从入门到精通:PyTorch基础知识(normal 函数)【PyTorch系统学习】

torch.normal() 的用法

        该函数的参数如下:

normal(mean, std, *, generator=None, out=None)

        参数说明

  1. mean:

    • 均值,可以是一个数值(标量)或者张量。
    • 如果是张量,则指定生成正态分布的均值,形状需与标准差匹配
  2. std:

    • 标准差,可以是一个数值(标量)或者张量。
    • 如果是张量,则指定生成正态分布的标准差,形状需与均值匹配
  3. generator (可选): 用于生成随机数的随机数生成器。

  4. out (可选): 如果提供,则结果存储在这个张量中。

        返回值

        返回一个与 meanstd 的形状匹配的张量,值是符合指定正态分布的随机数。

        函数使用示例

        生成一个标量值

import torch
# 生成均值为0,标准差为1的正态分布随机数
result = torch.normal(0, 1)
print(result)

        生成一个张量

import torch
# 生成均值为0,标准差为1的3x3张量
result = torch.normal(0, 1, size=(3, 3))
print(result)

        不同均值和标准差的张量

import torch
# 均值和标准差为张量
mean = torch.tensor([0.0, 1.0, 2.0])
std = torch.tensor([1.0, 0.5, 0.25])
result = torch.normal(mean, std)
print(result)

        高维张量的均值和标准差

        当函数的作用对象拓展到高维张量时,相信还有很多小伙伴不太理解,torch.normal()函数中的均值和标准差是如何体现和作用的,接下来就针对于不同情况来细致的讲解一下:

        情况一:当 meanstd 是标量时,生成的矩阵中所有元素都使用同一个均值和标准差。

import torch

# 生成一个 3x3 的矩阵,每个元素均值为 0,标准差为 1
matrix = torch.normal(0, 1, size=(3, 3))
print(matrix)

        例如,在上述的代码中,虽然我们生成的是一个3✖3的矩阵,但由于我们的均值和标准差都是标量,因此每个元素都是从均值为 0、标准差为 1 的正态分布中独立采样。

        情况二:当 meanstd 是张量时,生成的矩阵中每个位置的元素使用对应位置的均值和标准差。

import torch

# 均值和标准差为矩阵
mean = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
std = torch.tensor([[0.1, 0.2], [0.3, 0.4]])

# 生成矩阵
matrix = torch.normal(mean, std)
print(matrix)

        这个时候,对于生成的矩阵matrix:

  • matrix[0,0] 从 N( 1.0, 0.1^2) 中采样。
  • matrix[0,1] 从 N( 2.0, 0.2^2) 中采样。
  • matrix[1,0] 从 N( 3.0, 0.3^2) 中采样。
  • matrix[1,1] 从 N( 4.0, 0.4^2) 中采样。

        这也是为什么在参数说明中强调均值和标准差的形状需要相互匹配,并且返回值的形状在默认情况下会与 meanstd 的形状一致,在这道题中的形状即为(2,2)。

        情况三:如果 meanstd 的形状不同,但满足广播规则,PyTorch 会自动扩展较小的张量以匹配较大的张量。广播机制的介绍曾在先前的博客(PyTorch基础知识(张量))中有所涉及。

import torch

mean = torch.tensor([1.0, 2.0])  # 1D 张量,形状为 (2,)
std = torch.tensor([[0.1], [0.2]])  # 2D 张量,形状为 (2, 1)

# 广播机制
matrix = torch.normal(mean, std)
print(matrix)

广播过程

  1. mean 的形状扩展为 (2, 2),则拓展后的均值为[[1, 2],[1,2]]。
  2. std 的形状扩展为 (2, 2),则拓展后的标准差为[[ 0.1, 0.1 ], [ 0.2, 0.2]]

结果

  • matrix[0,0]  从 N(1.0,0.1^2) 中采样。
  • matrix[0,1]  从 N(2.0,0.1^2) 中采样。
  • matrix[1,0]  从 N(1.0,0.2^2) 中采样。
  • matrix[1,1]  从 N(2.0,0.2^2) 中采样。

感谢阅读,希望对你有所帮助~


http://www.kler.cn/a/404413.html

相关文章:

  • 北京申请中级职称流程(2024年)
  • [C++]:C++11(三)
  • 实验室管理效率提升:Spring Boot技术的力量
  • 什麼是ISP提供的公共IP地址?
  • 【IDEA】解决总是自动导入全部类(.*)问题
  • golang调用webview,webview2,go-webview2
  • 【英特尔IA-32架构软件开发者开发手册第3卷:系统编程指南】2001年版翻译,2-30
  • CSS中calc语法不生效
  • Android 从本地选择视频,用APP播放或进行其他处理
  • 缓冲区的奥秘:解析数据交错的魔法
  • C#(12) 内部类和分部类
  • 弹幕发送功能‘简单’实现
  • 数据集论文:面向深度学习的土地利用场景分类与变化检测
  • 设计模式-Adapter(适配器模式)GO语言版本
  • 2024信创数据库TOP30之达梦DM8
  • php:nginx如何配置WebSocket代理?
  • 接雨水
  • 智能工厂的设计软件 为了监管控一体化的全能Supervisor 的监督学习 之 序8 进化论及科学的信息技术创新:分布式账本/区块链/智能合约 之2
  • yolov5 数据集分享:纯干货
  • GEE 训练教程——Sentinel-1的卷积(核函数)的分析和可视化
  • this.$prompt 限制输入长度
  • Windows环境GeoServer打包Docker极速入门
  • 出海第一步:搞定业务系统的多区域部署
  • 大模型-微调与对齐-非强化学习的对齐方法
  • CSS3 动画:前端开发的动态美
  • 实现了图像处理、绘制三维坐标系以及图像合成的操作