PyTorch CPU 线程与 TorchScript 推断优化详解

2025-06-24 15:06 更新

在深度学习模型的部署和推断阶段,合理利用 CPU 线程和优化 TorchScript 推断性能是提升整体效率的关键。本文将深入浅出地讲解 PyTorch CPU 线程和 TorchScript 推断的相关知识,并通过实例帮助你掌握优化技巧。

一、PyTorch CPU 线程与并行机制

1.1 CPU 线程的并行级别

PyTorch 在模型推断过程中提供了不同级别的并行机制:

  • 推理线程级并行:多个推理线程可以同时执行模型的前向传播。
  • 操作间并行:通过 torch.jit._fork()torch.jit._wait(),可以在不同操作之间实现并行执行。
  • 操作内并行:在单个操作内部(如大张量元素操作、卷积等),利用多个 CPU 线程加速计算。

代码示例 1:操作间并行

import torch
import torch.jit as jit


@jit.script
def compute_z(x, w_z):
    return torch.mm(x, w_z)


@jit.script
def forward(x, w_y, w_z):
    # 异步启动 compute_z
    fut = jit._fork(compute_z, x, w_z)
    # 并行执行其他操作
    y = torch.mm(x, w_y)
    # 等待并获取结果
    z = jit._wait(fut)
    return y + z


## 使用示例
x = torch.randn(3, 3)
w_y = torch.randn(3, 3)
w_z = torch.randn(3, 3)
result = forward(x, w_y, w_z)
print(result)

1.2 并行后端的选择

PyTorch 支持多种并行后端,包括 OpenMP 和 TBB(Intel Threading Building Blocks)。不同的后端适用于不同的场景:

  • OpenMP:适用于基于循环的并行操作,广泛支持但可能存在线程池互操作性问题。
  • TBB:适用于任务调度和高并发场景,保证单个进程内使用统一的线程池。

构建选项

构建选项 备注
ATen ATEN_THREADING OMP(默认),TBB
MKL MKL_THREADING 同上 需要 BLAS=MKL 启用
MKL-DNN MKLDNN_THREADING 同上 需要 USE_MKLDNN=1 启用

注意:强烈建议不要在同一构建中混用 OpenMP 和 TBB。

二、控制线程设置的运行时 API

PyTorch 提供了多种运行时 API 来控制线程设置,帮助我们根据实际需求动态调整性能。

2.1 互操作并行(Inter-Operation Parallelism)

互操作并行控制多个推理任务之间的并行度。我们可以通过以下 API 进行设置:

  • torch.set_num_interop_threads(n):设置互操作线程数。
  • torch.get_num_interop_threads():获取当前互操作线程数。

代码示例 2

## 设置互操作线程数为 2
torch.set_num_interop_threads(2)
print("Interop threads:", torch.get_num_interop_threads())

2.2 操作内并行(Intra-Operation Parallelism)

操作内并行控制单个操作内部的线程使用情况。我们可以通过以下方式进行设置:

  • torch.set_num_threads(n):设置操作内线程数。
  • torch.get_num_threads():获取当前操作内线程数。
  • 环境变量:OMP_NUM_THREADSMKL_NUM_THREADS

优先级说明torch.set_num_threads() 优先级高于环境变量,而 MKL_NUM_THREADS 优先级高于 OMP_NUM_THREADS

代码示例 3

## 设置操作内线程数为 4
torch.set_num_threads(4)
print("Intra-op threads:", torch.get_num_threads())

三、线程设置的调试与优化

3.1 打印线程设置信息

PyTorch 提供了 torch.__config__.parallel_info() 方法,用于打印当前的线程设置信息,帮助我们进行调试和优化。

代码示例 4

print(torch.__config__.parallel_info())

输出示例

Parallel information:
    Intra-op parallelism: OPENMP with 4 threads
    Inter-op parallelism: 2 threads

3.2 线程设置的优化策略

根据实际应用场景,我们可以采取以下优化策略:

  • 推理场景:适当减少互操作线程数和操作内线程数,避免线程切换开销。
  • 训练场景:可以适当增加线程数,充分利用 CPU 资源。
  • 资源受限场景:根据可用 CPU 核心数,合理分配线程数,避免过度使用。

四、TorchScript 推断优化实践

4.1 TorchScript 简介

TorchScript 是 PyTorch 的一种中间表示形式,用于将 Python 定义的模型转换为可以在不同环境中高效运行的格式。它支持即时编译(JIT)和静态编译,能够显著提升推断性能。

4.2 TorchScript 推断的线程优化

在使用 TorchScript 进行推断时,合理设置线程数可以显著提升性能。以下是一个完整的优化流程:

步骤 1:模型转换为 TorchScript

## 定义模型
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)


    def forward(self, x):
        return self.fc(x)


## 创建模型实例并转换为 TorchScript
model = SimpleModel()
traced_model = torch.jit.trace(model, torch.randn(1, 10))
traced_model.save("model.pt")

步骤 2:加载 TorchScript 模型并设置线程数

## 加载 TorchScript 模型
loaded_model = torch.jit.load("model.pt")


## 设置线程数
torch.set_num_threads(4)
torch.set_num_interop_threads(2)


## 执行推断
input_data = torch.randn(1, 10)
output = loaded_model(input_data)
print(output)

4.3 性能对比

通过对比不同线程设置下的推断性能,我们可以找到最佳配置。以下是一个简单的性能测试代码:

import time


def benchmark(model, input_data, num_runs=1000):
    # 预热
    for _ in range(100):
        model(input_data)


    # 测试
    start_time = time.time()
    for _ in range(num_runs):
        model(input_data)
    end_time = time.time()


    return (end_time - start_time) / num_runs


## 测试不同线程设置的性能
input_data = torch.randn(1, 10)


## 设置 1 个操作内线程,1 个互操作线程
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
time_1 = benchmark(loaded_model, input_data)
print(f"1 线程:{time_1:.6f} 秒/次")


## 设置 4 个操作内线程,2 个互操作线程
torch.set_num_threads(4)
torch.set_num_interop_threads(2)
time_4 = benchmark(loaded_model, input_data)
print(f"4 线程:{time_4:.6f} 秒/次")


## 设置 8 个操作内线程,4 个互操作线程
torch.set_num_threads(8)
torch.set_num_interop_threads(4)
time_8 = benchmark(loaded_model, input_data)
print(f"8 线程:{time_8:.6f} 秒/次")

输出示例

1 线程:0.000321 秒/次
4 线程:0.000215 秒/次
8 线程:0.000232 秒/次

从结果可以看出,在本例中,设置 4 个操作内线程和 2 个互操作线程时性能最佳。

五、案例分析与总结

5.1 案例背景

在实际项目中,我们通常需要在 CPU 上高效运行深度学习模型,尤其是在资源受限的环境中(如边缘设备)。通过合理设置 CPU 线程和优化 TorchScript 推断,可以显著提升模型的运行效率。

5.2 案例总结

通过本文的介绍和实例,我们总结出以下关键点:

  • 合理利用 PyTorch 的并行机制(推理线程级并行、操作间并行、操作内并行)可以显著提升模型推断性能。
  • 根据实际场景选择合适的并行后端(OpenMP 或 TBB),并在构建 PyTorch 时正确配置。
  • 使用运行时 API 动态调整线程设置,并通过 torch.__config__.parallel_info() 进行调试。
  • 在 TorchScript 推断中,通过实验找到最佳的线程配置,平衡性能和资源使用。

## 设置线程数
programming_lion_threads = 4
w3cschool_interop_threads = 2
torch.set_num_threads(programming_lion_threads)
torch.set_num_interop_threads(w3cschool_interop_threads)

六、常见问题解答

Q1:如何确定最佳的线程设置?

A1:最佳线程设置取决于具体的应用场景和硬件环境。建议通过实验测试不同的线程配置,找到性能最优的组合。可以从操作内线程数 4、互操作线程数 2 开始测试,逐步调整。

Q2:TorchScript 推断是否支持 GPU 加速?

A2:是的,TorchScript 推断不仅支持 CPU,还支持 GPU 加速。在有 GPU 的环境中,可以通过 torch.cuda 相关 API 将模型和数据移动到 GPU 上进行计算。

Q3:如何进一步提升 TorchScript 推断性能?

A3:除了线程设置优化外,还可以尝试以下方法:

  • 使用 torch.jit.freeze() 冻结模型,减少运行时开销。
  • 使用 torch.jit.optimize_for_inference() 对模型进行优化。
  • 在支持的平台上使用量化(Quantization)技术减少模型大小并提升推理速度。

七、总结与展望

PyTorch 提供了灵活且强大的 CPU 线程控制和 TorchScript 推断优化机制。通过合理设置线程数、选择合适的并行后端以及优化 TorchScript 模型,我们可以在不同硬件环境下实现高效的模型推断。

对于初学者,建议从简单的模型开始,逐步尝试不同的线程配置和优化方法,观察性能变化。同时,关注 PyTorch 官方文档和社区动态,及时了解最新的性能优化技术。

关注编程狮(W3Cschool)平台,获取更多深度学习模型优化教程和案例,让你的模型在实际应用中表现更佳!

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号