PyTorch torch.utils.cpp_extension

2025-07-02 14:27 更新

PyTorch C++ 扩展开发详解

一、为什么需要 C++ 扩展?

在深度学习开发中,我们有时会遇到 Python 实现的性能瓶颈,或者需要调用现有的 C++_nlank 或 CUDA 代码。PyTorch 提供了强大的 C++ 扩展功能,允许开发者将自定义的 C++ 或 CUDA 代码集成到 PyTorch 模型中,从而提升性能或复用代码。

二、PyTorch C++ 扩展的核心工具

(一)torch.utils.cpp_extension.CppExtension

用于创建一个 setuptools.Extension,用于构建 C++ 扩展。

  • 参数
    • name:扩展的名称。
    • sources:C++ 源文件的路径列表。
    • 其他参数将传递给 setuptools.Extension 构造函数。

  • 示例

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension


setup(
    name='extension',
    ext_modules=[
        CppExtension(
            name='extension',
            sources=['extension.cpp'],
            extra_compile_args=['-g']
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

(二)torch.utils.cpp_extension.CUDAExtension

用于创建一个 setuptools.Extension,用于构建 CUDA/C++ 扩展。它会自动包含 CUDA 的相关路径和库。

  • 参数
    • name:扩展的名称。
    • sources:C++ 和 CUDA 源文件的路径列表。
    • 其他参数将传递给 setuptools.Extension 构造函数。

  • 示例

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='cuda_extension',
    ext_modules=[
        CUDAExtension(
            name='cuda_extension',
            sources=['extension.cpp', 'extension_kernel.cu'],
            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

(三)torch.utils.cpp_extension.BuildExtension

自定义 setuptools 构建扩展的类。它负责传递编译器标志,并支持混合 C++/CUDA 编译。

(四)torch.utils.cpp_extension.load

即时加载 PyTorch C++ 扩展(JIT)。它会编译源代码并将其作为模块加载到当前 Python 进程中。

  • 参数
    • name:扩展的名称。
    • sources:C++ 源文件的路径列表。
    • extra_cflags:传递给编译器的额外标志。
    • extra_cuda_cflags:传递给 nvcc 的额外标志。
    • extra_ldflags:传递给链接器的额外标志。
    • build_directory:构建工作区的路径。
    • verbose:是否开启详细日志。
    • with_cuda:是否包含 CUDA 支持。
    • is_python_module:是否作为 Python 模块加载。

  • 示例

from torch.utils.cpp_extension import load


module = load(
    name='extension',
    sources=['extension.cpp', 'extension_kernel.cu'],
    extra_cflags=['-O2'],
    verbose=True
)

(五)torch.utils.cpp_extension.load_inline

从字符串源即时加载 PyTorch C++ 扩展(JIT)。它与 load 类似,但源代码以字符串形式提供。

  • 参数
    • name:扩展的名称。
    • cpp_sources:C++ 源代码字符串或字符串列表。
    • cuda_sources:CUDA 源代码字符串或字符串列表。
    • functions:需要生成绑定的函数名称列表或字典。
    • 其他参数与 load 类似。

  • 示例

from torch.utils.cpp_extension import load_inline


source = '''
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
    return x.sin() + y.sin();
}
'''


module = load_inline(
    name='inline_extension',
    cpp_sources=[source],
    functions=['sin_add']
)

(六)辅助工具

  • torch.utils.cpp_extension.include_paths(cuda=False):获取构建 C++ 或 CUDA 扩展所需的包含路径。
  • torch.utils.cpp_extension.check_compiler_abi_compatibility(compiler):验证编译器是否与 PyTorch 兼容。
  • torch.utils.cpp_extension.verify_ninja_availability():检查系统上是否有 ninja 构建系统。

三、开发和构建 C++ 扩展的步骤

(一)准备源代码

编写 C++ 或 CUDA 源代码,实现自定义的功能或操作。

(二)编写 setup 脚本

创建一个 setup.py 脚本,使用 CppExtensionCUDAExtension 定义扩展,并使用 BuildExtension 构建扩展。

(三)构建和安装扩展

运行以下命令构建和安装扩展:

python setup.py install

(四)即时加载扩展

使用 loadload_inline 函数即时加载扩展,无需编写 setup.py 脚本。

四、总结

通过本教程,我们详细介绍了 PyTorch C++ 扩展的开发工具和流程。利用这些工具,开发者可以轻松地将自定义的 C++ 或 CUDA 代码集成到 PyTorch 项目中,从而提升模型性能或复用现有代码。掌握 C++ 扩展的开发方法,能够帮助您在深度学习项目中更好地优化性能和利用硬件资源。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号