行莫
行莫
发布于 2025-12-08 / 9 阅读
0
0

ONNX Runtime 跨平台模型推理引擎

ONNX Runtime 跨平台模型推理引擎

引言

想象一下,你要在不同平台上部署机器学习模型:

  • 方法1:为每个平台(Python、Java、C#、移动端)分别实现模型(工作量大、维护困难)
  • 方法2:使用统一的格式和运行时,一次训练,到处运行

ONNX Runtime 就像方法2——它是一个跨平台的模型推理引擎,让你可以在任何平台上运行机器学习模型,无需关心底层实现细节。

本文将用生动的类比、详细的代码示例和实际应用场景,带你深入了解 ONNX Runtime 的强大功能和如何使用它来部署机器学习模型。


第一部分:什么是 ONNX Runtime?

ONNX Runtime 的直观理解

ONNX Runtime 是一个高性能的跨平台推理引擎,用于运行 ONNX(Open Neural Network Exchange)格式的机器学习模型。

类比理解:

  • 就像Java 虚拟机(JVM):一次编写,到处运行
  • 就像Docker 容器:打包好的应用可以在任何环境运行
  • 就像通用翻译器:将模型转换为通用格式,任何平台都能理解
  • 就像跨平台播放器:可以在任何设备上播放标准格式的视频

ONNX Runtime 的核心特点

  1. 跨平台支持:Windows、Linux、macOS、Android、iOS 等
  2. 多语言支持:Python、Java、C#、C++、JavaScript 等
  3. 高性能:优化的推理引擎,支持 CPU、GPU、TPU
  4. 生产就绪:被 Microsoft、Facebook、Amazon 等公司广泛使用
  5. 易于集成:简单的 API,易于集成到现有项目

类比理解:

  • 就像通用适配器:连接不同平台和模型
  • 就像高性能引擎:优化的执行效率
  • 就像标准接口:统一的 API 设计

什么是 ONNX?

ONNX(Open Neural Network Exchange) 是一个开放的模型格式标准:

  • 开放标准:由 Microsoft、Facebook、Amazon 等公司共同制定
  • 格式统一:将不同框架的模型转换为统一格式
  • 互操作性:可以在不同框架和平台间转换

支持的框架:

  • PyTorch → ONNX
  • TensorFlow → ONNX
  • Keras → ONNX
  • Scikit-learn → ONNX
  • 等等...

类比理解:

  • 就像PDF 格式:任何设备都能打开
  • 就像MP3 格式:任何播放器都能播放
  • 就像JSON 格式:任何语言都能解析

第二部分:为什么需要 ONNX Runtime?

问题场景

场景1:多平台部署

  • 训练模型:Python + PyTorch
  • 部署环境:Java 后端、C# 桌面应用、移动端
  • 问题:如何在不同平台运行同一个模型?

场景2:性能优化

  • 训练框架:PyTorch(研究友好)
  • 生产环境:需要更高性能
  • 问题:如何在不改变模型的情况下提升性能?

场景3:框架迁移

  • 旧系统:TensorFlow 1.x
  • 新系统:需要迁移到其他框架
  • 问题:如何平滑迁移而不重写代码?

ONNX Runtime 的解决方案

1. 统一格式

  • 将模型转换为 ONNX 格式
  • 一次转换,到处运行

2. 高性能推理

  • 优化的执行引擎
  • 支持硬件加速(GPU、TPU)

3. 跨平台支持

  • 支持多种编程语言
  • 支持多种操作系统

类比理解:

  • 就像集装箱:标准化的容器,可以在任何运输工具上使用
  • 就像USB 接口:标准化的接口,任何设备都能连接
  • 就像HTTP 协议:标准化的通信协议,任何系统都能理解

第三部分:ONNX Runtime 的安装与配置

Python 安装

使用 pip 安装:

# CPU 版本
pip install onnxruntime

# GPU 版本(需要 CUDA)
pip install onnxruntime-gpu

# 验证安装
python -c "import onnxruntime as ort; print(ort.__version__)"

使用 conda 安装:

conda install -c conda-forge onnxruntime

Java 安装

Maven 依赖:

<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.16.0</version>
</dependency>

Gradle 依赖:

dependencies {
    implementation 'com.microsoft.onnxruntime:onnxruntime:1.16.0'
}

C# 安装

NuGet 包:

Install-Package Microsoft.ML.OnnxRuntime

Node.js 安装

npm install onnxruntime-node

第四部分:模型转换:从训练框架到 ONNX

PyTorch 模型转换

示例:转换一个简单的分类模型

import torch
import torch.onnx
import torchvision.models as models

# 1. 加载预训练模型
model = models.resnet50(pretrained=True)
model.eval()

# 2. 创建示例输入
dummy_input = torch.randn(1, 3, 224, 224)

# 3. 导出为 ONNX
torch.onnx.export(
    model,                          # 模型
    dummy_input,                    # 示例输入
    "resnet50.onnx",                # 输出文件路径
    input_names=['input'],          # 输入名称
    output_names=['output'],        # 输出名称
    dynamic_axes={
        'input': {0: 'batch_size'}, # 动态批次大小
        'output': {0: 'batch_size'}
    },
    opset_version=11                # ONNX 操作集版本
)

print("模型已成功导出为 ONNX 格式")

转换 Hugging Face 模型:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.onnx import export, FeaturesManager
import torch

# 1. 加载模型和分词器
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# 2. 获取模型特征
feature = FeaturesManager.get_supported_features(model.config.model_type)["default"]

# 3. 导出为 ONNX
onnx_path = "distilbert.onnx"
export(
    tokenizer,
    model,
    feature,
    opset=12,
    output=onnx_path
)

print(f"模型已导出到: {onnx_path}")

TensorFlow 模型转换

使用 tf2onnx:

import tensorflow as tf
import tf2onnx

# 1. 加载 TensorFlow 模型
model = tf.keras.models.load_model("my_model.h5")

# 2. 转换为 ONNX
onnx_model, _ = tf2onnx.convert.from_keras(model, output_path="model.onnx")

print("模型转换完成")

Scikit-learn 模型转换

使用 skl2onnx:

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# 1. 训练模型
X_train = np.random.rand(100, 4)
y_train = np.random.randint(0, 2, 100)
model = RandomForestClassifier()
model.fit(X_train, y_train)

# 2. 定义输入类型
initial_type = [('float_input', FloatTensorType([None, 4]))]

# 3. 转换为 ONNX
onnx_model = convert_sklearn(model, initial_types=initial_type)

# 4. 保存模型
with open("rf_model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

第五部分:Python 中使用 ONNX Runtime

基础使用

示例1:图像分类

import onnxruntime as ort
import numpy as np
from PIL import Image

# 1. 创建推理会话
session = ort.InferenceSession("resnet50.onnx")

# 2. 获取输入输出信息
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_shape = session.get_inputs()[0].shape

print(f"输入名称: {input_name}")
print(f"输入形状: {input_shape}")
print(f"输出名称: {output_name}")

# 3. 预处理图像
image = Image.open("cat.jpg")
image = image.resize((224, 224))
image_array = np.array(image).astype(np.float32)
image_array = image_array.transpose(2, 0, 1)  # HWC -> CHW
image_array = np.expand_dims(image_array, axis=0)  # 添加批次维度
image_array = image_array / 255.0  # 归一化

# 4. 运行推理
outputs = session.run([output_name], {input_name: image_array})

# 5. 处理输出
predictions = outputs[0]
predicted_class = np.argmax(predictions[0])
print(f"预测类别: {predicted_class}")
print(f"置信度: {predictions[0][predicted_class]:.4f}")

示例2:文本分类

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer

# 1. 加载分词器
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# 2. 创建推理会话
session = ort.InferenceSession("distilbert.onnx")

# 3. 准备输入
text = "I love this product!"
inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)

# 4. 运行推理
outputs = session.run(
    None,
    {
        "input_ids": inputs["input_ids"].astype(np.int64),
        "attention_mask": inputs["attention_mask"].astype(np.int64)
    }
)

# 5. 处理输出
logits = outputs[0]
predictions = np.argmax(logits, axis=-1)
print(f"预测结果: {predictions[0]}")

高级特性

1. 使用 GPU 加速

import onnxruntime as ort

# 创建会话选项
options = ort.SessionOptions()

# 设置执行提供者(GPU)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# 创建会话
session = ort.InferenceSession("model.onnx", options, providers=providers)

print(f"使用的执行提供者: {session.get_providers()}")

2. 批量推理

import onnxruntime as ort
import numpy as np

session = ort.InferenceSession("model.onnx")

# 准备批量输入
batch_size = 8
inputs = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)

# 批量推理
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: inputs})

print(f"批量推理完成,处理了 {batch_size} 个样本")

3. 动态输入形状

import onnxruntime as ort
import numpy as np

# 创建会话选项
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession("model.onnx", options)

# 使用不同大小的输入
for batch_size in [1, 4, 8, 16]:
    inputs = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
    outputs = session.run(None, {session.get_inputs()[0].name: inputs})
    print(f"批次大小 {batch_size}: 推理成功")

第六部分:Java 中使用 ONNX Runtime

基础使用

示例:图像分类

import ai.onnxruntime.*;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.*;

public class ONNXRuntimeExample {
    
    private OrtEnvironment env;
    private OrtSession session;
    private final int INPUT_SIZE = 224;
    
    public ONNXRuntimeExample(String modelPath) throws OrtException {
        // 1. 创建 ONNX Runtime 环境
        env = OrtEnvironment.getEnvironment();
        
        // 2. 创建会话选项
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        
        // 3. 加载模型
        session = env.createSession(modelPath, opts);
        
        System.out.println("模型加载成功");
        System.out.println("输入节点: " + session.getInputNames());
        System.out.println("输出节点: " + session.getOutputNames());
    }
    
    /**
     * 预处理图像
     */
    private float[][][][] preprocessImage(BufferedImage image) {
        // 调整大小
        BufferedImage resized = resizeImage(image, INPUT_SIZE, INPUT_SIZE);
        
        int width = resized.getWidth();
        int height = resized.getHeight();
        float[][][][] tensor = new float[1][3][height][width];
        
        // 转换为张量并归一化
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = resized.getRGB(x, y);
                int r = (rgb >> 16) & 0xFF;
                int g = (rgb >> 8) & 0xFF;
                int b = rgb & 0xFF;
                
                tensor[0][0][y][x] = r / 255.0f; // R
                tensor[0][1][y][x] = g / 255.0f; // G
                tensor[0][2][y][x] = b / 255.0f; // B
            }
        }
        
        return tensor;
    }
    
    /**
     * 调整图像大小
     */
    private BufferedImage resizeImage(BufferedImage original, int width, int height) {
        BufferedImage resized = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        java.awt.Graphics2D g = resized.createGraphics();
        g.setRenderingHint(java.awt.RenderingHints.KEY_INTERPOLATION,
                          java.awt.RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(original, 0, 0, width, height, null);
        g.dispose();
        return resized;
    }
    
    /**
     * 展平数组
     */
    private float[] flattenArray(float[][][][] array) {
        int totalSize = array.length * array[0].length * 
                       array[0][0].length * array[0][0][0].length;
        float[] flattened = new float[totalSize];
        int index = 0;
        
        for (float[][][] a : array) {
            for (float[][] b : a) {
                for (float[] c : b) {
                    for (float d : c) {
                        flattened[index++] = d;
                    }
                }
            }
        }
        
        return flattened;
    }
    
    /**
     * 运行推理
     */
    public float[] predict(BufferedImage image) throws OrtException {
        // 1. 预处理
        float[][][][] inputTensor = preprocessImage(image);
        
        // 2. 创建 ONNX Tensor
        long[] shape = {1, 3, INPUT_SIZE, INPUT_SIZE};
        OnnxTensor tensor = OnnxTensor.createTensor(
            env, 
            FloatBuffer.wrap(flattenArray(inputTensor)), 
            shape
        );
        
        // 3. 运行推理
        String inputName = session.getInputNames().iterator().next();
        OrtSession.Result outputs = session.run(
            Collections.singletonMap(inputName, tensor)
        );
        
        // 4. 获取输出
        OnnxValue outputValue = outputs.get(0);
        float[][] output = (float[][]) outputValue.getValue();
        
        // 5. 清理资源
        tensor.close();
        outputs.close();
        
        return output[0];
    }
    
    /**
     * 关闭资源
     */
    public void close() throws OrtException {
        if (session != null) {
            session.close();
        }
    }
    
    public static void main(String[] args) {
        try {
            // 1. 初始化模型
            ONNXRuntimeExample model = new ONNXRuntimeExample("resnet50.onnx");
            
            // 2. 加载图像
            BufferedImage image = ImageIO.read(new File("cat.jpg"));
            
            // 3. 运行推理
            float[] predictions = model.predict(image);
            
            // 4. 找到最大概率的类别
            int maxIndex = 0;
            float maxValue = predictions[0];
            for (int i = 1; i < predictions.length; i++) {
                if (predictions[i] > maxValue) {
                    maxValue = predictions[i];
                    maxIndex = i;
                }
            }
            
            System.out.println("预测类别: " + maxIndex);
            System.out.println("置信度: " + maxValue);
            
            // 5. 清理资源
            model.close();
            
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

高级特性

1. 使用 GPU

import ai.onnxruntime.*;

OrtEnvironment env = OrtEnvironment.getEnvironment();

// 创建会话选项
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();

// 添加 CUDA 执行提供者
try {
    opts.addCUDA(0); // 使用第一个 GPU
    System.out.println("使用 GPU 加速");
} catch (OrtException e) {
    System.out.println("GPU 不可用,使用 CPU");
}

OrtSession session = env.createSession("model.onnx", opts);

2. 批量推理

// 准备批量输入
int batchSize = 8;
float[][][][] batchInput = new float[batchSize][3][224][224];
// ... 填充数据 ...

// 创建批量张量
long[] shape = {batchSize, 3, 224, 224};
OnnxTensor tensor = OnnxTensor.createTensor(
    env,
    FloatBuffer.wrap(flattenBatchArray(batchInput)),
    shape
);

// 批量推理
OrtSession.Result outputs = session.run(
    Collections.singletonMap("input", tensor)
);

第七部分:实际应用案例

案例1:图像背景移除(BEN2 模型)

import ai.onnxruntime.*;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;

public class BEN2BackgroundRemoval {
    
    private OrtEnvironment env;
    private OrtSession session;
    
    public BEN2BackgroundRemoval(String modelPath) throws OrtException {
        env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        session = env.createSession(modelPath, opts);
    }
    
    public BufferedImage removeBackground(BufferedImage inputImage) throws OrtException {
        // 1. 预处理图像
        float[][][][] inputTensor = preprocessImage(inputImage);
        
        // 2. 创建输入张量
        long[] shape = {1, 3, 1024, 1024};
        OnnxTensor tensor = OnnxTensor.createTensor(
            env,
            FloatBuffer.wrap(flattenArray(inputTensor)),
            shape
        );
        
        // 3. 运行推理
        OrtSession.Result outputs = session.run(
            Collections.singletonMap("input", tensor)
        );
        
        // 4. 后处理输出
        float[][][][] output = (float[][][][]) outputs.get(0).getValue();
        BufferedImage result = postprocessOutput(output, 
                                                inputImage.getWidth(), 
                                                inputImage.getHeight());
        
        // 5. 清理资源
        tensor.close();
        outputs.close();
        
        return result;
    }
    
    // ... 预处理和后处理方法 ...
    
    public static void main(String[] args) {
        try {
            BEN2BackgroundRemoval model = new BEN2BackgroundRemoval("ben2.onnx");
            BufferedImage input = ImageIO.read(new File("input.jpg"));
            BufferedImage foreground = model.removeBackground(input);
            ImageIO.write(foreground, "PNG", new File("foreground.png"));
            model.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

案例2:文本情感分析

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer

class SentimentAnalyzer:
    def __init__(self, model_path, tokenizer_name):
        # 加载分词器
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        
        # 创建推理会话
        self.session = ort.InferenceSession(model_path)
        
    def analyze(self, text):
        # 1. 分词
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=128
        )
        
        # 2. 运行推理
        outputs = self.session.run(
            None,
            {
                "input_ids": inputs["input_ids"].astype(np.int64),
                "attention_mask": inputs["attention_mask"].astype(np.int64)
            }
        )
        
        # 3. 处理输出
        logits = outputs[0]
        probabilities = self.softmax(logits[0])
        
        return {
            "positive": float(probabilities[1]),
            "negative": float(probabilities[0])
        }
    
    @staticmethod
    def softmax(x):
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()

# 使用示例
analyzer = SentimentAnalyzer(
    "sentiment_model.onnx",
    "distilbert-base-uncased-finetuned-sst-2-english"
)

result = analyzer.analyze("I love this product!")
print(f"正面情感: {result['positive']:.2%}")
print(f"负面情感: {result['negative']:.2%}")

案例3:实时目标检测

import onnxruntime as ort
import cv2
import numpy as np

class ObjectDetector:
    def __init__(self, model_path):
        self.session = ort.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        
    def detect(self, image):
        # 1. 预处理
        resized = cv2.resize(image, (640, 640))
        input_tensor = resized.astype(np.float32) / 255.0
        input_tensor = np.transpose(input_tensor, (2, 0, 1))
        input_tensor = np.expand_dims(input_tensor, axis=0)
        
        # 2. 推理
        outputs = self.session.run(None, {self.input_name: input_tensor})
        
        # 3. 后处理(NMS、解码等)
        boxes, scores, classes = self.postprocess(outputs[0], image.shape)
        
        return boxes, scores, classes
    
    def postprocess(self, outputs, image_shape):
        # 实现 NMS 和坐标转换
        # ...
        return boxes, scores, classes

# 实时检测
detector = ObjectDetector("yolov5.onnx")
cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    boxes, scores, classes = detector.detect(frame)
    
    # 绘制检测结果
    for box, score, cls in zip(boxes, scores, classes):
        if score > 0.5:
            cv2.rectangle(frame, box[0], box[1], (0, 255, 0), 2)
            cv2.putText(frame, f"{cls}: {score:.2f}", 
                       box[0], cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
    
    cv2.imshow("Detection", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

第八部分:性能优化技巧

1. 图优化

Python:

import onnxruntime as ort

# 创建会话选项
options = ort.SessionOptions()

# 启用所有图优化
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# 使用优化的执行提供者
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

session = ort.InferenceSession("model.onnx", options, providers=providers)

Java:

OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
OrtSession session = env.createSession("model.onnx", opts);

2. 线程配置

options = ort.SessionOptions()
options.intra_op_num_threads = 4  # 操作内线程数
options.inter_op_num_threads = 2   # 操作间线程数

3. 内存优化

# 使用内存映射
options = ort.SessionOptions()
options.enable_mem_pattern = True
options.enable_cpu_mem_arena = True

4. 量化模型

使用量化工具:

from onnxruntime.quantization import quantize_dynamic, QuantType

# 动态量化
quantize_dynamic(
    "model.onnx",
    "model_quantized.onnx",
    weight_type=QuantType.QUInt8
)

第九部分:ONNX Runtime vs 其他方案

对比 TensorFlow Serving

特性ONNX RuntimeTensorFlow Serving
跨框架支持✅ 支持多种框架❌ 仅支持 TensorFlow
跨平台✅ 广泛支持⚠️ 主要 Linux
性能⭐⭐⭐⭐⭐⭐⭐⭐⭐
易用性⭐⭐⭐⭐⭐⭐⭐⭐
部署复杂度

对比 PyTorch JIT

特性ONNX RuntimePyTorch JIT
跨语言支持✅ 多语言❌ 仅 Python
跨平台✅ 广泛支持⚠️ 有限支持
性能⭐⭐⭐⭐⭐⭐⭐⭐⭐
模型兼容性✅ 多种框架❌ 仅 PyTorch

对比原生框架

特性ONNX Runtime原生框架
部署灵活性✅ 一次转换,到处运行❌ 需要为每个平台适配
性能⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
维护成本
学习曲线平缓陡峭

第十部分:最佳实践

1. 模型转换最佳实践

检查模型兼容性:

import onnx
from onnx import checker

# 加载模型
model = onnx.load("model.onnx")

# 检查模型
checker.check_model(model)
print("模型检查通过")

# 查看模型信息
print(f"ONNX 版本: {model.opset_import[0].version}")
print(f"输入: {[input.name for input in model.graph.input]}")
print(f"输出: {[output.name for output in model.graph.output]}")

测试转换后的模型:

import torch
import onnxruntime as ort
import numpy as np

# 1. 加载原始模型
original_model = torch.load("model.pth")
original_model.eval()

# 2. 加载 ONNX 模型
onnx_session = ort.InferenceSession("model.onnx")

# 3. 创建测试输入
test_input = torch.randn(1, 3, 224, 224)

# 4. 原始模型输出
with torch.no_grad():
    original_output = original_model(test_input).numpy()

# 5. ONNX 模型输出
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: test_input.numpy()})[0]

# 6. 比较结果
diff = np.abs(original_output - onnx_output).max()
print(f"最大差异: {diff}")
assert diff < 1e-5, "模型转换可能有误"

2. 错误处理

import onnxruntime as ort
import numpy as np

try:
    session = ort.InferenceSession("model.onnx")
    
    # 检查输入形状
    input_shape = session.get_inputs()[0].shape
    print(f"期望输入形状: {input_shape}")
    
    # 准备输入
    if None in input_shape:
        # 动态形状
        batch_size = 1
        input_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
    else:
        # 固定形状
        input_data = np.random.randn(*input_shape).astype(np.float32)
    
    # 运行推理
    outputs = session.run(None, {session.get_inputs()[0].name: input_data})
    
except ort.OrtException as e:
    print(f"ONNX Runtime 错误: {e}")
except Exception as e:
    print(f"其他错误: {e}")

3. 性能监控

import time
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")

# 预热
for _ in range(10):
    session.run(None, {session.get_inputs()[0].name: dummy_input})

# 性能测试
times = []
for _ in range(100):
    start = time.time()
    session.run(None, {session.get_inputs()[0].name: dummy_input})
    times.append(time.time() - start)

print(f"平均推理时间: {np.mean(times)*1000:.2f} ms")
print(f"最小推理时间: {np.min(times)*1000:.2f} ms")
print(f"最大推理时间: {np.max(times)*1000:.2f} ms")
print(f"吞吐量: {1/np.mean(times):.2f} FPS")

第十一部分:常见问题与解决方案

问题1:模型转换失败

原因:

  • 使用了不支持的算子
  • ONNX 版本不兼容
  • 动态形状处理不当

解决方案:

# 1. 检查支持的算子
import torch.onnx.symbolic_registry as registry
print("支持的算子:", list(registry._registry.keys()))

# 2. 使用更新的 opset 版本
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=13)

# 3. 处理不支持的算子
# 可能需要重写模型或使用替代方案

问题2:推理性能不佳

解决方案:

# 1. 启用图优化
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# 2. 使用 GPU
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# 3. 批量处理
# 批量推理通常比单个推理更高效

# 4. 量化模型
# 减少模型大小和计算量

问题3:内存不足

解决方案:

# 1. 减少批次大小
batch_size = 1  # 而不是 8 或 16

# 2. 使用内存映射
options = ort.SessionOptions()
options.enable_mem_pattern = True

# 3. 及时释放资源
tensor.close()
outputs.close()

问题4:跨平台兼容性问题

解决方案:

# 1. 使用标准数据类型
# 避免使用平台特定的类型

# 2. 测试不同平台
# 在目标平台上测试模型

# 3. 使用固定形状(如果可能)
# 避免动态形状带来的问题

第十二部分:总结

ONNX Runtime 的核心优势

  1. 跨平台:一次转换,到处运行
  2. 高性能:优化的推理引擎
  3. 多语言支持:Python、Java、C#、C++ 等
  4. 生产就绪:被大公司广泛使用
  5. 易于集成:简单的 API

适用场景

适合使用 ONNX Runtime:

  • ✅ 需要跨平台部署
  • ✅ 需要多语言支持
  • ✅ 需要高性能推理
  • ✅ 需要统一模型格式
  • ✅ 生产环境部署

不适合使用 ONNX Runtime:

  • ❌ 需要训练模型(ONNX Runtime 只用于推理)
  • ❌ 需要频繁修改模型结构
  • ❌ 使用了不支持的算子

学习建议

  1. 从简单开始:先转换简单的模型
  2. 测试验证:转换后验证模型正确性
  3. 性能优化:使用图优化和硬件加速
  4. 生产部署:在目标环境测试

类比总结

理解 ONNX Runtime,就像理解为什么使用标准格式:

  • 跨平台:就像 PDF 格式,任何设备都能打开
  • 高性能:就像优化的编译器,执行效率高
  • 统一标准:就像 HTTP 协议,任何系统都能理解
  • 易于集成:就像标准接口,易于连接

掌握 ONNX Runtime,你就能轻松地在任何平台上部署机器学习模型!


参考资料


评论