PyTorch 使用自定义 C ++运算符扩展 TorchScript
在深度学习模型开发中,有时需要使用自定义的 C++ 或 CUDA 函数来扩展 PyTorch 的功能,以实现高性能计算或调用第三方库。本文将详细指导您如何使用自定义 C++ 运算符扩展 PyTorch 的 TorchScript,帮助您在模型开发中突破 Python 的性能瓶颈,充分利用 C++ 的高效性和丰富库资源。
一、创建自定义 C++ 运算符
(一)编写运算符实现
首先,我们需要用 C++ 编写自定义运算符的实现。以下是一个将透视变换应用于图像的示例,该运算符利用了 OpenCV 库:
#include <opencv2/opencv.hpp>
#include <torch/script.h>
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(image.size(0), image.size(1), CV_32FC1, image.data<float>());
cv::Mat warp_mat(warp.size(0), warp.size(1), CV_32FC1, warp.data<float>());
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, {8, 8});
torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), {8, 8});
return output.clone();
}
(二)注册运算符
为了使 TorchScript 能够识别和使用自定义运算符,需要将其注册到 TorchScript 运行时中:
static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
二、构建共享库
将自定义运算符构建为共享库,以便在 Python 和 C++ 中加载和使用。
(一)编写 CMakeLists.txt
创建 CMakeLists.txt
文件来定义构建过程:
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(warp_perspective)
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)
add_library(warp_perspective SHARED op.cpp)
target_compile_features(warp_perspective PRIVATE cxx_range_for)
target_link_libraries(warp_perspective "${TORCH_LIBRARIES}")
target_link_libraries(warp_perspective opencv_core opencv_imgproc)
(二)编译共享库
在终端中运行以下命令来构建共享库:
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
这将在 build
目录下生成 libwarp_perspective.so
共享库。
三、在 Python 中使用自定义运算符
(一)加载共享库
在 Python 中,使用 torch.ops.load_library
函数加载共享库:
import torch
torch.ops.load_library("/path/to/libwarp_perspective.so")
(二)调用自定义运算符
加载共享库后,即可在 Python 中调用自定义运算符:
result = torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3))
print(result)
四、将自定义运算符用于 TorchScript 跟踪和脚本
(一)跟踪模式
在跟踪模式下,可以将自定义运算符集成到现有的 PyTorch 模型中:
def compute(x, y, z):
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + torch.relu(z)
inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
trace = torch.jit.trace(compute, inputs)
print(trace.graph)
(二)脚本模式
在脚本模式下,可以使用 @torch.jit.script
装饰器将函数转换为 TorchScript:
torch.ops.load_library("libwarp_perspective.so")
@torch.jit.script
def compute(x, y):
if bool(x[0][0] == 42):
z = 5
else:
z = 10
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + z
五、在 C++ 中使用自定义运算符
(一)加载共享库
在 C++ 项目中,需要将共享库与主应用程序链接。在 CMakeLists.txt 中添加链接设置:
target_link_libraries(example_app "${TORCH_LIBRARIES}")
target_link_libraries(example_app -Wl,--no-as-needed warp_perspective)
(二)加载和执行模型
在 C++ 中加载并执行 TorchScript 模型:
#include <torch/script.h>
#include <iostream>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "Usage: example-app <path-to-exported-script-module>\n";
return -1;
}
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::randn({4, 8}));
inputs.push_back(torch::randn({8, 5}));
torch::Tensor output = module->forward(std::move(inputs)).toTensor();
std::cout << output << std::endl;
}
六、总结与展望
通过本文,您已经掌握了如何使用自定义 C++ 运算符扩展 PyTorch 的 TorchScript。这一技术允许您将 C++ 的高性能计算能力与 PyTorch 的动态图和自动微分功能相结合,为复杂的深度学习任务提供强大的支持。未来,您可以进一步探索如何利用这一技术优化模型性能,或者将其他 C++ 库集成到 PyTorch 项目中。编程狮将持续为您带来更多深度学习模型开发和优化的优质教程,助力您的技术成长。
更多建议: