如何用ONNX Runtime运行PyTorch导出的模型并解决类型不兼容问题?(导出.如何用.不兼容.模型.运行...)

wufei1232025-03-16python633

如何用onnx runtime运行pytorch导出的模型并解决类型不兼容问题?

利用ONNX Runtime高效运行PyTorch模型

本文将指导您如何使用ONNX Runtime运行经torch.onnx.export导出的PyTorch模型,并重点解决PyTorch张量与ONNX Runtime所需NumPy数组类型不兼容的问题。

首先,我们来看一个PyTorch模型导出示例:

import torch

class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)

torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"]
)

这段代码定义了一个简单的PyTorch模型SumModule,并将其导出为名为onnx.pb的ONNX模型文件。

直接使用PyTorch张量作为ONNX Runtime的输入会导致错误,因为ONNX Runtime期望的是NumPy数组。 错误信息通常提示输入类型错误。

为了解决这个问题,我们需要将PyTorch张量转换为NumPy数组。 正确的代码如下:

import onnxruntime
import numpy as np
import torch

ort_session = onnxruntime.InferenceSession("onnx.pb")

# 关键修改:将torch.Tensor转换为np.ndarray
x = np.ones((2, 2), dtype=np.float32)

inputs = {ort_session.get_inputs()[0].name: x}
print(ort_session.run(None, inputs))

这段代码加载onnx.pb文件,创建一个形状为(2, 2),数据类型为float32的NumPy数组作为模型输入。 ort_session.get_inputs()[0].name 获取输入张量的名称,确保输入数据与模型定义匹配。 ort_session.run 函数运行模型并打印输出结果。

更简洁的等效代码:

import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession("onnx.pb")
input_data = np.ones((2, 2)).astype(np.float32)
output_data = sess.run(None, {"x": input_data})[0]
print(output_data)

这段代码功能相同,但更简洁易读。 关键在于使用NumPy数组作为输入。

通过以上方法,您可以成功加载并运行使用torch.onnx.export导出的PyTorch模型。 请确保输入数据的类型和形状与模型的预期输入相匹配。

以上就是如何用ONNX Runtime运行PyTorch导出的模型并解决类型不兼容问题?的详细内容,更多请关注知识资源分享宝库其它相关文章!

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。