PyTorch torch脚本

2025-06-25 15:07 更新

一、TorchScript 是什么?

TorchScript 是 PyTorch 提供的一种可以从 PyTorch 代码创建可序列化和可优化模型的方法。有了 TorchScript,我们可以将模型从 Python 环境导出,使其能够在脱离 Python 依赖的环境中运行。例如,在独立的 C++程序中使用模型。这对于将模型部署到生产环境非常有用,因为在生产环境中,Python 可能在性能和多线程方面存在局限性。

二、为什么需要 TorchScript?

  • 模型部署需求 :很多时候,我们希望在不支持 Python 或需要高性能的环境中运行模型,比如移动设备、嵌入式系统或某些服务器端服务。TorchScript 可以将模型转换为一种可在这些环境中高效运行的形式。
  • 性能优化 :通过 TorchScript,可以对模型进行优化,提高其运行效率,尤其是在处理大量数据或需要快速响应的场景中。
  • 代码可维护性 :将模型转换为 TorchScript 后,可以与 Python 代码分离,使得模型代码更加简洁、易维护,也便于团队协作和代码管理。

三、创建 TorchScript 代码

(一)脚本编写

  1. 函数脚本编写
    • 使用 @torch.jit.script 装饰器可以将一个 Python 函数转换为 TorchScript 函数。例如:

import torch


@torch.jit.script
def add_tensors(x, y):
    # type: (Tensor, Tensor) -> Tensor
    return x + y


print(type(add_tensors))  # 输出:torch.jit.ScriptFunction

在这里,我们定义了一个简单的函数 add_tensors,它将两个张量相加。@torch.jit.script 装饰器告诉 PyTorch 将这个函数编译为 TorchScript 函数。通过指定函数参数和返回值的类型(如 # type: (Tensor, Tensor) -> Tensor),我们可以使 TorchScript 更好地理解函数的类型信息,从而进行更有效的编译。

  • 要查看编译后的函数代码,可以使用 code 属性:

print(add_tensors.code)

这将输出编译后的 TorchScript 代码,类似于 Python 语法,但它是经过优化的静态类型代码。

  1. 模块脚本编写
    • 对于 nn.Module,可以使用 torch.jit.script 将其转换为 TorchScript 模块。默认情况下,它会编译模块的 forward 方法以及 forward 中调用的任何方法、子模块和函数。例如:

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = nn.Parameter(torch.rand(N, M))
        self.linear = nn.Linear(N, M)


    def forward(self, input):
        output = self.weight.mv(input)
        output = self.linear(output)
        return output


scripted_module = torch.jit.script(MyModule(2, 3))

在这个例子中,我们定义了一个简单的神经网络模块 MyModule,它包含一个可学习的权重参数和一个线性层。torch.jit.script 将其转换为 TorchScript 模块 scripted_module,该模块可以像原始模块一样使用,但在运行时会使用 TorchScript 解释器来执行,具有更好的性能和可部署性。

(二)跟踪(Tracing)

  1. 函数跟踪
    • 使用 torch.jit.trace 可以跟踪一个函数并返回一个可执行文件或 ScriptFunction。例如:

import torch


def multiply_tensors(x, y):
    return 2 * x + y


traced_foo = torch.jit.trace(multiply_tensors, (torch.rand(3), torch.rand(3)))

这里,我们定义了一个函数 multiply_tensors,它将两个张量进行线性组合。torch.jit.trace 会运行该函数,并记录在给定示例输入(torch.rand(3), torch.rand(3))下执行的所有张量操作,生成一个 ScriptFunction 对象 traced_foo,该对象可以脱离 Python 环境运行。

  1. 模块跟踪
    • 对于 nn.Module,可以使用 torch.jit.tracetorch.jit.trace_module 进行跟踪。例如:

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)


    def forward(self, x):
        return self.conv(x)


n = Net()
example_forward_input = torch.rand(1, 1, 3, 3)


## 跟踪模块的 forward 方法
traced_module = torch.jit.trace(n, example_forward_input)

在这个示例中,我们定义了一个简单的卷积神经网络模块 Net,并使用 torch.jit.trace 对其 forward 方法进行跟踪,生成一个 ScriptModule 对象 traced_module,该对象包含跟踪后的 forward 方法,可以在非 Python 环境中使用。

四、保存和加载 TorchScript 模型

(一)保存模型

  1. 使用 torch.jit.save 可以将 TorchScript 模块保存到文件或类似文件的对象中。例如:

import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


m = torch.jit.script(MyModule())


## 保存到文件
torch.jit.save(m, 'scriptmodule.pt')


## 保存到 io.BytesIO 缓冲区
buffer = io.BytesIO()
torch.jit.save(m, buffer)

这里,我们将一个简单的模块 MyModule 转换为 TorchScript 模块 m,然后将其保存到文件 scriptmodule.pt 和内存缓冲区 buffer 中。保存后的模型可以被加载到 C++ API 或 Python API 中使用。

(二)加载模型

  1. 使用 torch.jit.load 可以加载先前保存的 ScriptModuleScriptFunction。例如:

import torch
import io


## 从文件加载
loaded_module = torch.jit.load('scriptmodule.pt')


## 从 io.BytesIO 缓冲区加载
buffer = io.BytesIO()
## 假设 buffer 中包含已保存的模型数据
loaded_module_from_buffer = torch.jit.load(buffer)

加载后的模型可以像原始模型一样使用,但在运行时会使用 TorchScript 解释器来执行。

五、混合跟踪和脚本编写

在实际应用中,我们常常会将跟踪和脚本编写结合起来使用,以充分利用两者的优势。例如:

  1. 在脚本中调用跟踪的函数

import torch


def multiply_tensors(x, y):
    return 2 * x + y


traced_foo = torch.jit.trace(multiply_tensors, (torch.rand(3), torch.rand(3)))


@torch.jit.script
def call_traced_function(x):
    return traced_foo(x, x)

在这个例子中,我们先使用跟踪生成了一个跟踪函数 traced_foo,然后在脚本函数 call_traced_function 中调用它。

  1. 在跟踪函数中调用脚本函数

import torch


@torch.jit.script
def add_tensors(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def combine_tensors(x, y, z):
    return add_tensors(x, y) + z


traced_combine = torch.jit.trace(combine_tensors, (torch.rand(3), torch.rand(3), torch.rand(3)))

这里,我们先定义了一个脚本函数 add_tensors,然后在函数 combine_tensors 中调用它,并使用跟踪生成了一个跟踪函数 traced_combine

六、迁移到 PyTorch 1.2 递归脚本 API

PyTorch 1.2 对 TorchScript API 进行了一些更改,主要包括:

  1. torch.jit.script 现在会尝试递归编译遇到的函数、方法和类。这意味着编译过程变得更加全面和自动化。
  2. 推荐使用 torch.jit.script(nn_module_instance) 来创建 ScriptModule,而不是继承自 torch.jit.ScriptModule。这种新的用法更加简单易用。

例如:

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)


    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return torch.relu(self.conv2(x))


my_model = Model()
my_scripted_model = torch.jit.script(my_model)

在这个例子中,我们将一个普通的 nn.Module 子类 Model 转换为 ScriptModule my_scripted_model,而无需继承自 torch.jit.ScriptModule

七、常见问题解答

  1. Q:我想在 GPU 上训练模型并在 CPU 上进行推理。最佳做法是什么?
  2. A:首先将模型从 GPU 转换为 CPU,然后将其保存。例如:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")


traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")


## 后续使用模型时
if use_gpu:
    model = torch.jit.load("gpu.pth")
else:
    model = torch.jit.load("cpu.pth")


model(input)

推荐这样做是因为跟踪器可能会在特定设备上见证张量的创建,因此在保存之前对模型进行转换可确保跟踪器具有正确的设备信息。

  1. Q:如何在 ScriptModule 上存储属性?

  1. **A:有以下几种方法:
    • 使用 nn.Parameter :包装在 nn.Parameter 中的值将像在 nn.Module 上一样工作。
    • 使用 register_buffer :包装在 register_buffer 中的值将像在 nn.Module 上一样工作。
    • 标记为常量:将类成员注释为 Final(或在类定义级别将其添加到名为 __constants__ 的列表中)会将包含的名称标记为常量。
    • 作为可变属性:可以将支持的类型的值添加为可变属性。大多数类型可以被推断,但可能需要指定一些类型。**

例如:

import torch
from typing import List


class MyModule(torch.nn.Module):
    a: torch.jit.Final[int]
    words: List[str]


    def __init__(self):
        super(MyModule, self).__init__()
        self.a = 10  # 类型被推断为 int
        self.words = []  # 类型需要显式指定
        self.register_buffer('buffer', torch.zeros(10))  # 使用 register_buffer


    def forward(self, x):
        return x + self.a


m = torch.jit.script(MyModule())

在这个例子中,我们展示了如何在 ScriptModule 中存储不同类型的属性。

  1. Q:我想跟踪模块的方法,但一直出现此错误:RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

  1. A:此错误通常表示您要跟踪的方法使用模块的参数,并且您正在传递模块的方法而不是模块实例。要跟踪模块上的特定方法,请使用 torch.jit.trace_module 。例如:

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.param = nn.Parameter(torch.rand(10))


    def forward(self, x):
        return x + self.param


    def another_method(self, x):
        return x * self.param


module = MyModule()
example_input = torch.rand(5)


## 跟踪模块的多个方法
traced_module = torch.jit.trace_module(
    module,
    inputs={'forward': example_input, 'another_method': example_input}
)

在这个例子中,我们使用 torch.jit.trace_module 跟踪了 MyModule 的多个方法,避免了将参数作为常量插入的错误。

八、总结

通过本教程,我们详细介绍了 PyTorch 中的 TorchScript,包括它的概念、创建方法、保存和加载方式,以及如何结合跟踪和脚本编写来满足不同的需求。希望这些内容能帮助您更好地理解和使用 TorchScript,将模型有效地部署到各种环境中。

记住,在使用 TorchScript 时,要根据实际情况选择合适的创建方法(脚本编写或跟踪),并注意保存和加载模型的细节。同时,了解如何处理常见问题将有助于您在实际项目中顺利应用 TorchScript。

如果您想进一步深入学习 PyTorch 和 TorchScript,可以访问编程狮(w3cschool.cn)网站,那里有更多的教程和资源供您参考。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号