PyTorch torch.onnx

2025-06-25 15:30 更新

一、什么是 ONNX?

ONNX(Open Neural Network Exchange)是一种开放的神经网络交换格式,旨在促进不同深度学习框架之间的模型互操作性。通过将 PyTorch 模型导出为 ONNX 格式,我们可以在其他支持 ONNX 的框架和工具中使用这些模型,如 Caffe2、Microsoft ONNX Runtime 等。这对于模型的部署和优化具有重要意义。

二、PyTorch 模型导出 ONNX 的基本流程

(一)示例:将预训练的 AlexNet 导出到 ONNX

  1. 导入必要的库

import torch
import torchvision

  1. 准备输入和模型

## 创建一个虚拟输入,形状为 (10, 3, 224, 224),并将其移动到 GPU 上
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')


## 加载预训练的 AlexNet 模型,并将其移动到 GPU 上
model = torchvision.models.alexnet(pretrained=True).cuda()


## 将模型设置为评估模式(非训练模式)
model.eval()

  1. 定义输入和输出名称

## 为模型的输入和参数指定名称,以提高模型图的可读性
input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
output_names = ["output1"]

  1. 导出模型到 ONNX

## 将模型导出到 ONNX 文件 "alexnet.onnx"
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

(二)验证导出的 ONNX 模型

  1. 使用 ONNX 库验证模型

import onnx


## 加载导出的 ONNX 模型
model = onnx.load("alexnet.onnx")


## 检查模型的 IR 是否有效
onnx.checker.check_model(model)


## 打印模型图的可读表示形式
print(onnx.helper.printable_graph(model.graph))

  1. 使用 ONNX Runtime 运行模型

import onnxruntime as ort
import numpy as np


## 创建 ONNX Runtime 推理会话
ort_session = ort.InferenceSession('alexnet.onnx')


## 准备输入数据
input_data = np.random.randn(10, 3, 224, 224).astype(np.float32)


## 运行模型并获取输出
outputs = ort_session.run(None, {'actual_input_1': input_data})


## 打印输出结果
print(outputs[0])

三、跟踪与脚本编写

(一)基于跟踪的导出器

基于跟踪的导出器通过执行一次模型并导出在此运行期间实际执行的运算符来操作。这意味着如果模型是动态的(例如,根据输入数据更改行为),导出可能不准确。同样,跟踪可能仅对特定的输入大小有效。我们建议检查模型跟踪并确保所跟踪的运算符看起来合理。

例如:

import torch


## 定义一个简单的模型类
class LoopModel(torch.nn.Module):
    def forward(self, x, y):
        for i in range(y):
            x = x + i
        return x


model = LoopModel()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)


## 使用基于跟踪的导出器导出模型
torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True)

(二)基于脚本的导出器

基于脚本的导出器表示要导出的模型是 ScriptModule。ScriptModule 是 TorchScript 中的核心数据结构,TorchScript 是 Python 语言的子集,可用于从 PyTorch 代码创建可序列化和可优化的模型。

例如:

@torch.jit.script
def loop(x, y):
    for i in range(int(y)):
        x = x + i
    return x


class LoopModel2(torch.nn.Module):
    def forward(self, x, y):
        return loop(x, y)


model = LoopModel2()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)


## 使用基于脚本的导出器导出模型
torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True, input_names=['input_data', 'loop_range'])

四、局限性和常见问题

  1. 张量就地索引分配不支持 :目前导出中不支持张量就地索引分配,如 data[index] = new_data。可以通过使用 scatter_ 运算符来解决此类问题。
  2. ONNX 中没有张量列表的概念 :这使得导出消耗或产生张量列表的运算符变得困难,尤其是在导出时不知道张量列表的长度的情况下。
  3. 输入大小固定问题 :如果模型应接受动态形状的输入,可以在导出 API 中使用参数 dynamic_axes 来指定动态轴。
  4. 隐式标量数据类型转换问题 :ONNX 不支持隐式标量数据类型转换,但导出器会尝试处理该部分。对于无法自动处理的情况,需要手动提供数据类型信息。

五、总结

通过本教程,我们详细介绍了如何将 PyTorch 模型导出为 ONNX 格式,包括基本流程、跟踪与脚本编写的区别,以及一些局限性和常见问题的解决方案。将模型导出为 ONNX 格式可以提高模型的互操作性和部署灵活性,使我们能够在各种支持 ONNX 的框架和工具中使用这些模型。

在实际应用中,根据模型的特点和需求,选择合适的导出方式(基于跟踪或基于脚本),并注意处理可能遇到的局限性和问题,可以确保模型成功导出并能够在目标环境中正常运行。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号