PyTorch 使用自定义 C ++类扩展 TorchScript

2025-06-23 10:24 更新

TorchScript 是 PyTorch 的一种中间表示形式,允许开发者将模型及其执行逻辑编译为高效的序列化格式,便于部署和优化。在许多场景下,开发者可能需要将自定义的 C++ 类集成到 TorchScript 中,以利用 C++ 的高性能特性或调用第三方库。本文将详细讲解如何使用自定义 C++ 类扩展 TorchScript,并通过实例演示其在 Python 和 C++ 环境中的应用。

一、定义自定义 C++ 类

(一)类的基本结构

首先,我们需要定义一个继承自 torch::jit::CustomClassHolder 的 C++ 类。这个基类确保了自定义类能够与 PyTorch 的生命周期管理系统兼容。

#include <torch/script.h>
#include <torch/custom_class.h>
#include <string>
#include <vector>


template <class T>
struct Stack : torch::jit::CustomClassHolder {
    std::vector<T> stack_;


    Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}


    void push(T x) {
        stack_.push_back(x);
    }


    T pop() {
        auto val = stack_.back();
        stack_.pop_back();
        return val;
    }


    c10::intrusive_ptr<Stack> clone() const {
        return c10::make_intrusive<Stack>(stack_);
    }


    void merge(const c10::intrusive_ptr<Stack>& c) {
        for (auto& elem : c->stack_) {
            push(elem);
        }
    }
};

注意:c10::intrusive_ptr 是一个智能指针,用于管理对象的生命周期,类似于 std::shared_ptr

(二)注册自定义类

为了使自定义类在 TorchScript 和 Python 中可见,需要使用 torch::jit::class_ 进行注册。

static auto testStack = torch::jit::class_<Stack<std::string>>("Stack")
    .def(torch::jit::init<std::vector<std::string>>())
    .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
        return self->stack_.back();
    })
    .def("push", &Stack<std::string>::push)
    .def("pop", &Stack<std::string>::pop)
    .def("clone", &Stack<std::string>::clone)
    .def("merge", &Stack<std::string>::merge);

二、构建共享库

将自定义类的实现和注册代码编译为共享库,以便在不同环境中加载和使用。

(一)创建 CMakeLists.txt

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_class)


find_package(Torch REQUIRED)


add_library(custom_class SHARED class.cpp)
set(CMAKE_CXX_STANDARD 14)


target_link_libraries(custom_class "${TORCH_LIBRARIES}")

(二)编译共享库

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make

三、在 Python 和 TorchScript 中使用自定义类

(一)加载共享库

在 Python 中使用 torch.classes.load_library 加载共享库:

import torch
torch.classes.load_library("libcustom_class.so")

(二)实例化和使用自定义类

s = torch.classes.Stack(["foo", "bar"])
s.push("pushed")
assert s.pop() == "pushed"


s2 = s.clone()
s.merge(s2)


for expected in ["bar", "foo", "bar", "foo"]:
    assert s.pop() == expected

(三)在 TorchScript 中使用自定义类

Stack = torch.classes.Stack


@torch.jit.script
def do_stacks(s: Stack) -> (Stack, str):
    s2 = torch.classes.Stack(["hi", "mom"])
    s2.merge(s)
    return s2.clone(), s2.pop()


stack, top = do_stacks(torch.classes.Stack(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
    assert stack.pop() == expected

四、在 C++ 中加载和运行 TorchScript 模型

(一)定义模型并保存

import torch
torch.classes.load_library('libcustom_class.so')


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()


    def forward(self, s: str) -> str:
        stack = torch.classes.Stack(["hi", "mom"])
        return stack.pop() + s


scripted_foo = torch.jit.script(Foo())
scripted_foo.save('foo.pt')

(二)加载模型并在 C++ 中运行

#include <torch/script.h>
#include <iostream>


int main(int argc, const char* argv[]) {
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("foo.pt");
    } catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }


    std::vector<c10::IValue> inputs = {"foobarbaz"};
    auto output = module.forward(inputs).toString();
    std::cout << output->string() << std::endl;
}

(三)构建和运行 C++ 项目

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(infer)


find_package(Torch REQUIRED)
add_subdirectory(custom_class_project)
add_executable(infer infer.cpp)
set(CMAKE_CXX_STANDARD 14)
target_link_libraries(infer "${TORCH_LIBRARIES}")
target_link_libraries(infer -Wl,--no-as-needed custom_class)

在项目目录中运行以下命令:

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
./infer

五、总结与拓展

通过本文,您已掌握如何在 PyTorch 中使用自定义 C++ 类扩展 TorchScript。这一技能在需要高性能计算或调用第三方 C++ 库时尤为有用。未来,您可以进一步探索如何将自定义类与深度学习模型结合,以实现更高效的训练和推理流程。编程狮将持续为您提供更多深度学习模型开发和优化的优质教程,助力您的技术成长。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号