PyTorch torch.utils.checkpoint

2025-07-02 14:22 更新

PyTorch 检查点机制详解:优化显存使用与模型训练效率

一、检查点机制是什么?

检查点(Checkpoint)机制是一种用于优化深度学习模型训练过程中显存使用的技巧。在训练复杂的深度学习模型时,尤其是大型神经网络,显存资源往往非常有限。检查点机制通过在正向传播过程中丢弃某些中间激活结果,然后在反向传播过程中重新计算这些中间结果,从而减少显存占用。

二、PyTorch 检查点函数详解

(一)torch.utils.checkpoint.checkpoint(function, *args, preserve_rng_state=True)

  1. 基本原理
    • 在正向传播阶段,function 会以 torch.no_grad() 模式运行,即不保存中间激活结果。仅保存输入张量和 function 参数。
    • 在反向传播阶段,通过重新运行 function 来重新计算中间激活结果,然后基于这些结果计算梯度。

  1. 参数说明
    • function:定义模型正向传播过程的函数。该函数应能够处理输入元组并正确执行前向计算。
    • args:传递给 function 的输入张量元组。
    • preserve_rng_state:布尔值,默认为 True。如果为 True,则在检查点过程中保存并恢复随机数生成器(RNG)状态,以确保使用随机操作(如 dropout)时结果的确定性。

  1. 注意事项
    • 检查点机制不支持 torch.autograd.grad(),仅支持 torch.autograd.backward()
    • 如果反向传播期间的 function 调用与正向传播期间的调用存在差异(例如由于全局变量的影响),则可能导致结果不一致。

(二)torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, preserve_rng_state=True)

  1. 基本原理
    • 适用于顺序执行的模型或模块列表。将模型划分为多个段,每个段对应一个检查点。
    • 除最后一个段外,其他段均以 torch.no_grad() 模式运行,不保存中间激活结果。每个检查点段的输入会被保存,以便在反向传播时重新计算该段的正向结果。

  1. 参数说明
    • functions:一个 torch.nn.Sequential 对象或包含多个模块 / 函数的列表。
    • segments:模型被划分为的段数。
    • inputs:传递给 functions 的输入张量元组。
    • preserve_rng_state:布尔值,默认为 True。是否在每个检查点期间保存和恢复 RNG 状态。

三、实际应用案例

(一)单个模块的检查点应用

假设我们有一个简单的神经网络模块,我们希望对该模块应用检查点以减少显存占用。

  1. import torch
  2. import torch.nn as nn
  3. import torch.utils.checkpoint as cp
  4. class CheckpointModel(nn.Module):
  5. def __init__(self):
  6. super(CheckpointModel, self).__init__()
  7. self.layer1 = nn.Linear(10, 10)
  8. self.layer2 = nn.Linear(10, 10)
  9. self.layer3 = nn.Linear(10, 2)
  10. def forward(self, x):
  11. # 对 layer2 应用检查点
  12. x = self.layer1(x)
  13. x = cp.checkpoint(self.layer2, x)
  14. x = self.layer3(x)
  15. return x
  16. model = CheckpointModel()
  17. input_var = torch.randn(1, 10)
  18. output = model(input_var)

(二)顺序模型的检查点应用

对于顺序执行的模型,我们可以使用 checkpoint_sequential 来划分检查点段。

  1. model = nn.Sequential(
  2. nn.Linear(10, 10),
  3. nn.ReLU(),
  4. nn.Linear(10, 10),
  5. nn.ReLU(),
  6. nn.Linear(10, 2)
  7. )
  8. input_var = torch.randn(1, 10)
  9. segments = 2 # 将模型划分为 2 个段
  10. output = cp.checkpoint_sequential(model, segments, input_var)

四、性能与显存权衡

使用检查点机制虽然可以有效减少显存占用,但会增加计算时间,因为需要在反向传播过程中重新计算中间激活结果。在实际应用中,需要根据模型规模、显存限制和训练时间要求等因素,合理选择是否应用检查点机制以及如何划分检查点段。

五、总结

通过本教程,我们详细介绍了 PyTorch 中的检查点机制及其应用方法。检查点机制在训练大型深度学习模型时,能够有效减少显存占用,提高模型训练的可行性。正确理解和使用检查点机制,可以帮助我们在有限的硬件资源下训练更复杂的模型。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号