PyTorch 概率分布-torch分布
2025-06-25 14:56 更新
一、PyTorch 概率分布概述
PyTorch 的 torch.distributions
包提供了丰富的概率分布类和采样函数,可用于构建随机计算图和实现随机梯度估计器。以下是常用分布及其关键方法的介绍。
(一)常用分布类
- 正态分布(Normal)
- 可通过均值(
loc
)和标准差(scale
)参数化。 - 提供
sample()
、rsample()
、log_prob()
等方法,分别用于采样、可微采样和计算对数概率密度。
- 可通过均值(
- 伯努利分布(Bernoulli)
- 由成功概率(
probs
)或对数几率(logits
)参数化。 - 常用于二分类问题中的随机采样。
- 由成功概率(
- 均匀分布(Uniform)
- 在指定区间
[low, high)
内生成均匀分布的随机样本。
- 在指定区间
- 分类分布(Categorical)
- 适用于多分类问题,可基于类别概率(
probs
)或对数几率(logits
)进行采样。
- 适用于多分类问题,可基于类别概率(
(二)关键方法
sample()
:从分布中生成随机样本。rsample()
:生成可微样本,利用重参数化技巧实现梯度回传。log_prob(value)
:计算给定值的对数概率密度或质量。entropy()
:计算分布的熵。
二、实际案例:强化学习中的策略梯度
假设我们正在开发一个强化学习模型,用于在模拟环境中训练智能体。我们将利用 PyTorch 的概率分布实现策略梯度方法。
import torch
import torch.distributions as td
## 定义一个简单的策略网络
class PolicyNetwork(nn.Module):
def __init__(self, input_dim, output_dim):
super(PolicyNetwork, self).__init__()
self.fc = nn.Linear(input_dim, output_dim)
def forward(self, x):
return torch.softmax(self.fc(x), dim=-1)
## 初始化策略网络
policy_net = PolicyNetwork(input_dim=4, output_dim=2)
## 模拟环境状态
state = torch.tensor([0.1, 0.2, 0.3, 0.4])
## 使用策略网络输出动作概率
probs = policy_net(state)
## 创建分类分布
m = td.Categorical(probs=probs)
## 采样动作
action = m.sample()
## 计算动作的对数概率
log_prob = m.log_prob(action)
## 假设获得奖励
reward = torch.tensor(1.0)
## 计算损失并反向传播
loss = -log_prob * reward
loss.backward()
三、实际案例:变分自编码器中的重参数化
在变分自编码器(VAE)中,我们利用重参数化技巧实现路径导数估计器。
import torch
import torch.distributions as td
## 定义一个简单的变分自编码器
class VAE(nn.Module):
def __init__(self, input_dim, latent_dim):
super(VAE, self).__init__()
self.encoder = nn.Linear(input_dim, latent_dim * 2)
self.decoder = nn.Linear(latent_dim, input_dim)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, logvar = torch.chunk(self.encoder(x), 2, dim=-1)
z = self.reparameterize(mu, logvar)
reconstructed = self.decoder(z)
return reconstructed, mu, logvar
## 初始化 VAE
vae = VAE(input_dim=784, latent_dim=20)
## 输入数据
x = torch.randn(1, 784)
## 前向传播
reconstructed, mu, logvar = vae(x)
## 定义损失函数
def vae_loss(reconstructed, x, mu, logvar):
reconstruction_loss = nn.MSELoss()(reconstructed, x)
kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return reconstruction_loss + kl_divergence
## 计算损失
loss = vae_loss(reconstructed, x, mu, logvar)
loss.backward()
四、总结
本教程为零基础的初学者详细讲解了 PyTorch 中的概率分布,包括常用分布及其关键方法。通过实际案例,展示了如何在强化学习和变分自编码器中应用这些分布。希望读者能通过这些知识,充分利用 PyTorch 的概率分布功能,加速深度学习项目。
以上内容是否对您有帮助:
更多建议: