PyTorch 使用自定义 C ++类扩展 TorchScript
TorchScript 是 PyTorch 的一种中间表示形式,允许开发者将模型及其执行逻辑编译为高效的序列化格式,便于部署和优化。在许多场景下,开发者可能需要将自定义的 C++ 类集成到 TorchScript 中,以利用 C++ 的高性能特性或调用第三方库。本文将详细讲解如何使用自定义 C++ 类扩展 TorchScript,并通过实例演示其在 Python 和 C++ 环境中的应用。
一、定义自定义 C++ 类
(一)类的基本结构
首先,我们需要定义一个继承自 torch::jit::CustomClassHolder
的 C++ 类。这个基类确保了自定义类能够与 PyTorch 的生命周期管理系统兼容。
#include <torch/script.h>
#include <torch/custom_class.h>
#include <string>
#include <vector>
template <class T>
struct Stack : torch::jit::CustomClassHolder {
std::vector<T> stack_;
Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}
void push(T x) {
stack_.push_back(x);
}
T pop() {
auto val = stack_.back();
stack_.pop_back();
return val;
}
c10::intrusive_ptr<Stack> clone() const {
return c10::make_intrusive<Stack>(stack_);
}
void merge(const c10::intrusive_ptr<Stack>& c) {
for (auto& elem : c->stack_) {
push(elem);
}
}
};
注意:c10::intrusive_ptr
是一个智能指针,用于管理对象的生命周期,类似于 std::shared_ptr
。
(二)注册自定义类
为了使自定义类在 TorchScript 和 Python 中可见,需要使用 torch::jit::class_
进行注册。
static auto testStack = torch::jit::class_<Stack<std::string>>("Stack")
.def(torch::jit::init<std::vector<std::string>>())
.def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
return self->stack_.back();
})
.def("push", &Stack<std::string>::push)
.def("pop", &Stack<std::string>::pop)
.def("clone", &Stack<std::string>::clone)
.def("merge", &Stack<std::string>::merge);
二、构建共享库
将自定义类的实现和注册代码编译为共享库,以便在不同环境中加载和使用。
(一)创建 CMakeLists.txt
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_class)
find_package(Torch REQUIRED)
add_library(custom_class SHARED class.cpp)
set(CMAKE_CXX_STANDARD 14)
target_link_libraries(custom_class "${TORCH_LIBRARIES}")
(二)编译共享库
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
三、在 Python 和 TorchScript 中使用自定义类
(一)加载共享库
在 Python 中使用 torch.classes.load_library
加载共享库:
import torch
torch.classes.load_library("libcustom_class.so")
(二)实例化和使用自定义类
s = torch.classes.Stack(["foo", "bar"])
s.push("pushed")
assert s.pop() == "pushed"
s2 = s.clone()
s.merge(s2)
for expected in ["bar", "foo", "bar", "foo"]:
assert s.pop() == expected
(三)在 TorchScript 中使用自定义类
Stack = torch.classes.Stack
@torch.jit.script
def do_stacks(s: Stack) -> (Stack, str):
s2 = torch.classes.Stack(["hi", "mom"])
s2.merge(s)
return s2.clone(), s2.pop()
stack, top = do_stacks(torch.classes.Stack(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
assert stack.pop() == expected
四、在 C++ 中加载和运行 TorchScript 模型
(一)定义模型并保存
import torch
torch.classes.load_library('libcustom_class.so')
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, s: str) -> str:
stack = torch.classes.Stack(["hi", "mom"])
return stack.pop() + s
scripted_foo = torch.jit.script(Foo())
scripted_foo.save('foo.pt')
(二)加载模型并在 C++ 中运行
#include <torch/script.h>
#include <iostream>
int main(int argc, const char* argv[]) {
torch::jit::script::Module module;
try {
module = torch::jit::load("foo.pt");
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::vector<c10::IValue> inputs = {"foobarbaz"};
auto output = module.forward(inputs).toString();
std::cout << output->string() << std::endl;
}
(三)构建和运行 C++ 项目
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(infer)
find_package(Torch REQUIRED)
add_subdirectory(custom_class_project)
add_executable(infer infer.cpp)
set(CMAKE_CXX_STANDARD 14)
target_link_libraries(infer "${TORCH_LIBRARIES}")
target_link_libraries(infer -Wl,--no-as-needed custom_class)
在项目目录中运行以下命令:
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
./infer
五、总结与拓展
通过本文,您已掌握如何在 PyTorch 中使用自定义 C++ 类扩展 TorchScript。这一技能在需要高性能计算或调用第三方 C++ 库时尤为有用。未来,您可以进一步探索如何将自定义类与深度学习模型结合,以实现更高效的训练和推理流程。编程狮将持续为您提供更多深度学习模型开发和优化的优质教程,助力您的技术成长。
更多建议: