当前位置: 首页 > article >正文

torch 如何实现两点分布采样,100个样本中20个样本为1,80个为2

方法1

可以使用 PyTorch 中的 torch.distributions 模块实现两点分布采样。具体来说,可以使用 Categorical 分布将数字1和数字2的概率分别设为0.2和0.8,然后调用 sample() 方法进行采样。

下面是实现上述功能的代码示例:

import torch
# 创建 Categorical 分布(数字1和数字2的概率分别设为0.2和0.8)
probs = torch.tensor([0.2, 0.8])
dist = torch.distributions.Categorical(probs)
# 从分布中采样100个样本
samples = dist.sample(torch.Size([100]))
# 统计样本中数字1和数字2的数量
count_1 = torch.sum(samples == 0)
count_2 = torch.sum(samples == 1)
print(f"数字1的数量:{count_1}")
print(f"数字2的数量:{count_2}")

输出结果类似于:

数字1的数量:22
数字2的数量:78

方法2

可以先使用 torch.ones()torch.zeros() 函数生成分别包含20个数字1和80个数字2的张量,然后使用 torch.cat() 函数将它们拼接在一起,再使用 torch.randperm() 函数对其进行打乱。

下面是实现上述功能的代码示例:

import torch
# 生成包含20个数字1和80个数字2的张量,并拼接在一起
ones_tensor = torch.ones(20)
zeros_tensor = torch.zeros(80)
data_tensor = torch.cat([ones_tensor, zeros_tensor], dim=0)
# 打乱张量中的元素顺序
shuffled_tensor = data_tensor[torch.randperm(data_tensor.shape[0])]
print(shuffled_tensor)

输出结果为:

tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1.,
        1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])

其中,数字1被表示为1.0,数字2被表示为2.0。


http://www.kler.cn/news/160000.html

相关文章:

  • Docker-多容器应用
  • 算法题:买汽水(瓶子瓶盖换水)
  • Linux下的查看文件的命令
  • 面试被问到 HTTP和HTTPS的区别有哪些?你该如何回答~
  • 制作一个RISC-V的操作系统五-RISC-V汇编语言编程三
  • Python-炸弹人【附完整源码】
  • 【C/C++指针】指针*与引用的区别
  • 12.06 二叉树中等题2
  • 安网AC智能路由系统actpt_5g.data敏感信息泄露漏洞复现 [附POC]
  • 表单修改时取消disabled snippet
  • 【节日专栏】Python海龟绘制圣诞树代码
  • 0X05
  • 一、CSharp_Basic:什么是.Net平台?什么是.Net FrameWork?什么是C#?
  • C# Solidworks二次开发:获取零件的最小包容体方法详解
  • 关于mysql的lower_case_table_names引发的思考
  • C语言词法陷阱
  • 《C++ primer》 anki学习卡片txt输出101张,更新至第2章,截止2023年12月6日
  • 计算机操作系统3
  • C语言猜数字小游戏
  • java单人聊天
  • 模式识别与机器学习(七):集成学习
  • Python高级数据结构——并查集(Disjoint Set)
  • Multidimensional Scaling(MDS多维缩放)算法及其应用
  • docker安装mysql8
  • Python 模块的使用方法
  • 万宾科技监测设备,可燃气体监测仪特点一览
  • PostgreSQL有意思的现象:支持不带列的表
  • Java 数据结构篇-用链表、数组实现队列(数组实现:循环队列)
  • 【动手学深度学习】(六)权重衰退
  • 【Unity入门】声音组件AudioSource简介及实现声音的近大远小