PyTorch 使用自定义 C ++运算符扩展 TorchScript

2025-06-23 10:15 更新

在深度学习模型开发中,有时需要使用自定义的 C++ 或 CUDA 函数来扩展 PyTorch 的功能,以实现高性能计算或调用第三方库。本文将详细指导您如何使用自定义 C++ 运算符扩展 PyTorch 的 TorchScript,帮助您在模型开发中突破 Python 的性能瓶颈,充分利用 C++ 的高效性和丰富库资源。

一、创建自定义 C++ 运算符

(一)编写运算符实现

首先,我们需要用 C++ 编写自定义运算符的实现。以下是一个将透视变换应用于图像的示例,该运算符利用了 OpenCV 库:

#include <opencv2/opencv.hpp>
#include <torch/script.h>


torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
    cv::Mat image_mat(image.size(0), image.size(1), CV_32FC1, image.data<float>());
    cv::Mat warp_mat(warp.size(0), warp.size(1), CV_32FC1, warp.data<float>());
    cv::Mat output_mat;
    cv::warpPerspective(image_mat, output_mat, warp_mat, {8, 8});
    torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), {8, 8});
    return output.clone();
}

(二)注册运算符

为了使 TorchScript 能够识别和使用自定义运算符,需要将其注册到 TorchScript 运行时中:

static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);

二、构建共享库

将自定义运算符构建为共享库,以便在 Python 和 C++ 中加载和使用。

(一)编写 CMakeLists.txt

创建 CMakeLists.txt 文件来定义构建过程:

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(warp_perspective)


find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)


add_library(warp_perspective SHARED op.cpp)


target_compile_features(warp_perspective PRIVATE cxx_range_for)
target_link_libraries(warp_perspective "${TORCH_LIBRARIES}")
target_link_libraries(warp_perspective opencv_core opencv_imgproc)

(二)编译共享库

在终端中运行以下命令来构建共享库:

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make

这将在 build 目录下生成 libwarp_perspective.so 共享库。

三、在 Python 中使用自定义运算符

(一)加载共享库

在 Python 中,使用 torch.ops.load_library 函数加载共享库:

import torch
torch.ops.load_library("/path/to/libwarp_perspective.so")

(二)调用自定义运算符

加载共享库后,即可在 Python 中调用自定义运算符:

result = torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3))
print(result)

四、将自定义运算符用于 TorchScript 跟踪和脚本

(一)跟踪模式

在跟踪模式下,可以将自定义运算符集成到现有的 PyTorch 模型中:

def compute(x, y, z):
    x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
    return x.matmul(y) + torch.relu(z)


inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
trace = torch.jit.trace(compute, inputs)
print(trace.graph)

(二)脚本模式

在脚本模式下,可以使用 @torch.jit.script 装饰器将函数转换为 TorchScript:

torch.ops.load_library("libwarp_perspective.so")


@torch.jit.script
def compute(x, y):
    if bool(x[0][0] == 42):
        z = 5
    else:
        z = 10
    x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
    return x.matmul(y) + z

五、在 C++ 中使用自定义运算符

(一)加载共享库

在 C++ 项目中,需要将共享库与主应用程序链接。在 CMakeLists.txt 中添加链接设置:

target_link_libraries(example_app "${TORCH_LIBRARIES}")
target_link_libraries(example_app -Wl,--no-as-needed warp_perspective)

(二)加载和执行模型

在 C++ 中加载并执行 TorchScript 模型:

#include <torch/script.h>
#include <iostream>


int main(int argc, const char* argv[]) {
    if (argc != 2) {
        std::cerr << "Usage: example-app <path-to-exported-script-module>\n";
        return -1;
    }


    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({4, 8}));
    inputs.push_back(torch::randn({8, 5}));
    torch::Tensor output = module->forward(std::move(inputs)).toTensor();
    std::cout << output << std::endl;
}

六、总结与展望

通过本文,您已经掌握了如何使用自定义 C++ 运算符扩展 PyTorch 的 TorchScript。这一技术允许您将 C++ 的高性能计算能力与 PyTorch 的动态图和自动微分功能相结合,为复杂的深度学习任务提供强大的支持。未来,您可以进一步探索如何利用这一技术优化模型性能,或者将其他 C++ 库集成到 PyTorch 项目中。编程狮将持续为您带来更多深度学习模型开发和优化的优质教程,助力您的技术成长。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号