PyTorch 概率分布-torch分布

2025-06-25 14:56 更新

一、PyTorch 概率分布概述

PyTorchtorch.distributions 包提供了丰富的概率分布类和采样函数,可用于构建随机计算图和实现随机梯度估计器。以下是常用分布及其关键方法的介绍。

(一)常用分布类

  1. 正态分布(Normal)
    • 可通过均值(loc)和标准差(scale)参数化。
    • 提供 sample()rsample()log_prob() 等方法,分别用于采样、可微采样和计算对数概率密度。

  1. 伯努利分布(Bernoulli)
    • 由成功概率(probs)或对数几率(logits)参数化。
    • 常用于二分类问题中的随机采样。

  1. 均匀分布(Uniform)
    • 在指定区间 [low, high) 内生成均匀分布的随机样本。

  1. 分类分布(Categorical)
    • 适用于多分类问题,可基于类别概率(probs)或对数几率(logits)进行采样。

(二)关键方法

  1. sample() :从分布中生成随机样本。
  2. rsample() :生成可微样本,利用重参数化技巧实现梯度回传。
  3. log_prob(value) :计算给定值的对数概率密度或质量。
  4. entropy() :计算分布的熵。

二、实际案例:强化学习中的策略梯度

假设我们正在开发一个强化学习模型,用于在模拟环境中训练智能体。我们将利用 PyTorch 的概率分布实现策略梯度方法。

import torch
import torch.distributions as td


## 定义一个简单的策略网络
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)


    def forward(self, x):
        return torch.softmax(self.fc(x), dim=-1)


## 初始化策略网络
policy_net = PolicyNetwork(input_dim=4, output_dim=2)


## 模拟环境状态
state = torch.tensor([0.1, 0.2, 0.3, 0.4])


## 使用策略网络输出动作概率
probs = policy_net(state)


## 创建分类分布
m = td.Categorical(probs=probs)


## 采样动作
action = m.sample()


## 计算动作的对数概率
log_prob = m.log_prob(action)


## 假设获得奖励
reward = torch.tensor(1.0)


## 计算损失并反向传播
loss = -log_prob * reward
loss.backward()

三、实际案例:变分自编码器中的重参数化

在变分自编码器(VAE)中,我们利用重参数化技巧实现路径导数估计器。

import torch
import torch.distributions as td


## 定义一个简单的变分自编码器
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = nn.Linear(input_dim, latent_dim * 2)
        self.decoder = nn.Linear(latent_dim, input_dim)


    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std


    def forward(self, x):
        mu, logvar = torch.chunk(self.encoder(x), 2, dim=-1)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decoder(z)
        return reconstructed, mu, logvar


## 初始化 VAE
vae = VAE(input_dim=784, latent_dim=20)


## 输入数据
x = torch.randn(1, 784)


## 前向传播
reconstructed, mu, logvar = vae(x)


## 定义损失函数
def vae_loss(reconstructed, x, mu, logvar):
    reconstruction_loss = nn.MSELoss()(reconstructed, x)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + kl_divergence


## 计算损失
loss = vae_loss(reconstructed, x, mu, logvar)
loss.backward()

四、总结

本教程为零基础的初学者详细讲解了 PyTorch 中的概率分布,包括常用分布及其关键方法。通过实际案例,展示了如何在强化学习和变分自编码器中应用这些分布。希望读者能通过这些知识,充分利用 PyTorch 的概率分布功能,加速深度学习项目。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号