PyTorch distributions模块介绍
PyTorch 的 torch.distributions
模块提供了对概率分布的全面支持,允许用户通过对象化的方式定义、操作和采样各种常见分布。该模块适用于概率建模、生成模型(如变分自动编码器 VAE)、强化学习等需要使用分布的场景。每种分布都有通用的接口来计算概率、对数概率、采样等。
常用的分布类型
以下是 torch.distributions
模块中一些常见的分布:
Categorical
- 离散分类分布,可以用于从多类别中采样。
- 用法:输入类别的概率
probs
或对数概率logits
。
from torch.distributions import Categorical
probs = torch.tensor([0.2, 0.5, 0.3])
dist = Categorical(probs)
sample = dist.sample() # 从 [0, 1, 2] 中采样
Normal
- 正态分布(高斯分布),常用于生成连续值或噪声。
- 用法:需要指定均值
mean
和标准差std
。
from torch.distributions import Normal
dist = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
sample = dist.sample((5,)) # 从正态分布中采样 5 个样本