PyTorch 使用自定义 C ++类扩展 TorchScript

2025-06-23 10:24 更新

TorchScript 是 PyTorch 的一种中间表示形式,允许开发者将模型及其执行逻辑编译为高效的序列化格式,便于部署和优化。在许多场景下,开发者可能需要将自定义的 C++ 类集成到 TorchScript 中,以利用 C++ 的高性能特性或调用第三方库。本文将详细讲解如何使用自定义 C++ 类扩展 TorchScript,并通过实例演示其在 Python 和 C++ 环境中的应用。

一、定义自定义 C++ 类

(一)类的基本结构

首先,我们需要定义一个继承自 torch::jit::CustomClassHolder 的 C++ 类。这个基类确保了自定义类能够与 PyTorch 的生命周期管理系统兼容。

  1. #include <torch/script.h>
  2. #include <torch/custom_class.h>
  3. #include <string>
  4. #include <vector>
  5. template <class T>
  6. struct Stack : torch::jit::CustomClassHolder {
  7. std::vector<T> stack_;
  8. Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}
  9. void push(T x) {
  10. stack_.push_back(x);
  11. }
  12. T pop() {
  13. auto val = stack_.back();
  14. stack_.pop_back();
  15. return val;
  16. }
  17. c10::intrusive_ptr<Stack> clone() const {
  18. return c10::make_intrusive<Stack>(stack_);
  19. }
  20. void merge(const c10::intrusive_ptr<Stack>& c) {
  21. for (auto& elem : c->stack_) {
  22. push(elem);
  23. }
  24. }
  25. };

注意:c10::intrusive_ptr 是一个智能指针,用于管理对象的生命周期,类似于 std::shared_ptr

(二)注册自定义类

为了使自定义类在 TorchScript 和 Python 中可见,需要使用 torch::jit::class_ 进行注册。

  1. static auto testStack = torch::jit::class_<Stack<std::string>>("Stack")
  2. .def(torch::jit::init<std::vector<std::string>>())
  3. .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
  4. return self->stack_.back();
  5. })
  6. .def("push", &Stack<std::string>::push)
  7. .def("pop", &Stack<std::string>::pop)
  8. .def("clone", &Stack<std::string>::clone)
  9. .def("merge", &Stack<std::string>::merge);

二、构建共享库

将自定义类的实现和注册代码编译为共享库,以便在不同环境中加载和使用。

(一)创建 CMakeLists.txt

  1. cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
  2. project(custom_class)
  3. find_package(Torch REQUIRED)
  4. add_library(custom_class SHARED class.cpp)
  5. set(CMAKE_CXX_STANDARD 14)
  6. target_link_libraries(custom_class "${TORCH_LIBRARIES}")

(二)编译共享库

  1. mkdir build
  2. cd build
  3. cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  4. make

三、在 Python 和 TorchScript 中使用自定义类

(一)加载共享库

在 Python 中使用 torch.classes.load_library 加载共享库:

  1. import torch
  2. torch.classes.load_library("libcustom_class.so")

(二)实例化和使用自定义类

  1. s = torch.classes.Stack(["foo", "bar"])
  2. s.push("pushed")
  3. assert s.pop() == "pushed"
  4. s2 = s.clone()
  5. s.merge(s2)
  6. for expected in ["bar", "foo", "bar", "foo"]:
  7. assert s.pop() == expected

(三)在 TorchScript 中使用自定义类

  1. Stack = torch.classes.Stack
  2. @torch.jit.script
  3. def do_stacks(s: Stack) -> (Stack, str):
  4. s2 = torch.classes.Stack(["hi", "mom"])
  5. s2.merge(s)
  6. return s2.clone(), s2.pop()
  7. stack, top = do_stacks(torch.classes.Stack(["wow"]))
  8. assert top == "wow"
  9. for expected in ["wow", "mom", "hi"]:
  10. assert stack.pop() == expected

四、在 C++ 中加载和运行 TorchScript 模型

(一)定义模型并保存

  1. import torch
  2. torch.classes.load_library('libcustom_class.so')
  3. class Foo(torch.nn.Module):
  4. def __init__(self):
  5. super().__init__()
  6. def forward(self, s: str) -> str:
  7. stack = torch.classes.Stack(["hi", "mom"])
  8. return stack.pop() + s
  9. scripted_foo = torch.jit.script(Foo())
  10. scripted_foo.save('foo.pt')

(二)加载模型并在 C++ 中运行

  1. #include <torch/script.h>
  2. #include <iostream>
  3. int main(int argc, const char* argv[]) {
  4. torch::jit::script::Module module;
  5. try {
  6. module = torch::jit::load("foo.pt");
  7. } catch (const c10::Error& e) {
  8. std::cerr << "error loading the model\n";
  9. return -1;
  10. }
  11. std::vector<c10::IValue> inputs = {"foobarbaz"};
  12. auto output = module.forward(inputs).toString();
  13. std::cout << output->string() << std::endl;
  14. }

(三)构建和运行 C++ 项目

  1. cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
  2. project(infer)
  3. find_package(Torch REQUIRED)
  4. add_subdirectory(custom_class_project)
  5. add_executable(infer infer.cpp)
  6. set(CMAKE_CXX_STANDARD 14)
  7. target_link_libraries(infer "${TORCH_LIBRARIES}")
  8. target_link_libraries(infer -Wl,--no-as-needed custom_class)

在项目目录中运行以下命令:

  1. mkdir build
  2. cd build
  3. cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  4. make
  5. ./infer

五、总结与拓展

通过本文,您已掌握如何在 PyTorch 中使用自定义 C++ 类扩展 TorchScript。这一技能在需要高性能计算或调用第三方 C++ 库时尤为有用。未来,您可以进一步探索如何将自定义类与深度学习模型结合,以实现更高效的训练和推理流程。编程狮将持续为您提供更多深度学习模型开发和优化的优质教程,助力您的技术成长。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号