行莫
行莫
发布于 2025-12-08 / 8 阅读
0
0

DJL 与引擎无关的深度学习 Java 框架

DJL 与引擎无关的深度学习 Java 框架

引言

想象一下,你要在 Java 项目中集成深度学习模型:

  • 方法1:为每个框架(PyTorch、TensorFlow、MXNet)分别写代码,维护多套代码(工作量大、维护困难)
  • 方法2:使用统一的 API,一套代码支持所有框架,轻松切换和部署

DJL(Deep Java Library) 就像方法2——它是 AWS 开发的统一深度学习接口,让你可以在 Java 中使用 PyTorch、TensorFlow、MXNet、ONNX Runtime、PaddlePaddle 等所有主流深度学习框架,无需关心底层实现细节。

本文将用生动的类比、详细的代码示例和实际应用场景,带你从零开始,深入掌握 DJL 的强大功能,让你能够在 Java 生态系统中轻松进行深度学习开发。


第一部分:什么是 DJL?

DJL 的直观理解

DJL(Deep Java Library) 是 AWS 开发的开源深度学习库,为 Java 开发者提供了统一的 API 来使用各种深度学习框架。

类比理解:

  • 就像统一接口:一套 API 支持所有深度学习框架,就像 USB-C 接口可以连接各种设备
  • 就像适配器模式:将不同框架的差异隐藏起来,就像电源适配器适配不同电压
  • 就像翻译器:将 Java 代码翻译成不同框架的调用,就像同声传译员
  • 就像万能钥匙:一把钥匙打开所有框架的门,就像万能钥匙可以开所有锁

DJL 的核心特点

  1. 框架无关:支持 PyTorch、TensorFlow、MXNet、ONNX Runtime、PaddlePaddle
  2. 统一 API:一套代码,支持所有框架,无需修改代码即可切换框架
  3. 易于使用:Java 风格的 API,易于学习和使用,符合 Java 开发习惯
  4. 高性能:直接调用底层框架,性能优秀,接近原生框架性能
  5. Hugging Face 集成:可以直接加载 Hugging Face 模型,无需转换
  6. 生产就绪:被 AWS 等公司广泛使用,企业级支持
  7. 跨平台:支持 Windows、Linux、macOS 等操作系统
  8. GPU 支持:支持 CUDA、MPS 等 GPU 加速

类比理解:

  • 就像统一遥控器:一个遥控器控制所有电器,不需要为每个电器准备一个遥控器
  • 就像多语言翻译器:可以理解多种语言,自动翻译成目标语言
  • 就像通用接口:标准化的接口设计,任何设备都能连接

为什么选择 DJL?

与其他 Java 深度学习库的对比:

特性DJLDeeplearning4jSmileONNX Runtime
框架支持⭐⭐⭐⭐⭐ 多种框架⭐⭐ 仅自己的实现⭐⭐ 传统 ML⭐⭐⭐⭐ 仅 ONNX
统一 API✅ 是❌ 否❌ 否⚠️ 仅 ONNX
Hugging Face✅ 直接支持❌ 不支持❌ 不支持⚠️ 需转换
模型训练✅ 支持✅ 支持✅ 支持❌ 仅推理
易用性⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
性能⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
社区支持⭐⭐⭐⭐ AWS 支持⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
文档质量⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐

选择 DJL 的理由:

  • 框架无关:一套代码支持所有框架,灵活切换
  • 统一 API:学习成本低,一套 API 掌握所有框架
  • Hugging Face 集成:可以直接加载模型,无需转换
  • 易于集成:Java 风格,易于集成到现有项目
  • AWS 支持:企业级支持,持续更新
  • 生产就绪:被广泛使用,稳定可靠

DJL 的适用场景

适合使用 DJL:

  • ✅ Java 项目需要深度学习功能
  • ✅ 需要加载 Hugging Face 模型
  • ✅ 需要支持多种框架
  • ✅ 需要统一的 API
  • ✅ 企业级应用
  • ✅ 微服务架构中的模型服务

不适合使用 DJL:

  • ❌ 只需要传统机器学习(使用 Smile)
  • ❌ 只需要 ONNX 模型推理(使用 ONNX Runtime)
  • ❌ Python 项目(使用原生框架更合适)
  • ❌ 需要复杂的模型训练(Python 生态更丰富)

第二部分:DJL 的安装与配置

系统要求

Java 版本:

  • Java 8 或更高版本(推荐 Java 11+)

操作系统:

  • Windows 10/11
  • Linux(Ubuntu、CentOS 等)
  • macOS 10.14+

内存:

  • 至少 4GB RAM(推荐 8GB+)
  • GPU 内存根据模型大小而定

Maven 依赖配置

1. 基础依赖(必需):

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.25.0</version>
</dependency>

2. 选择引擎(至少选择一个):

<!-- PyTorch 引擎(最常用,推荐) -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.25.0</version>
</dependency>

<!-- TensorFlow 引擎 -->
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-engine</artifactId>
    <version>0.25.0</version>
</dependency>

<!-- ONNX Runtime 引擎 -->
<dependency>
    <groupId>ai.djl.onnxruntime</groupId>
    <artifactId>onnxruntime-engine</artifactId>
    <version>0.25.0</version>
</dependency>

<!-- MXNet 引擎 -->
<dependency>
    <groupId>ai.djl.mxnet</groupId>
    <artifactId>mxnet-engine</artifactId>
    <version>0.25.0</version>
</dependency>

3. PyTorch 原生库(CPU 版本):

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <version>2.0.1</version>
    <classifier>${os}-x86_64</classifier>
</dependency>

4. PyTorch 原生库(GPU 版本 - CUDA 11.8):

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cu118</artifactId>
    <version>2.0.1</version>
    <classifier>${os}-x86_64</classifier>
</dependency>

5. Hugging Face 集成(可选):

<dependency>
    <groupId>ai.djl.huggingface</groupId>
    <artifactId>tokenizers</artifactId>
    <version>0.25.0</version>
</dependency>

6. 图像处理支持(可选):

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.25.0</version>
</dependency>

完整的 pom.xml 示例:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
         http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    
    <groupId>com.example</groupId>
    <artifactId>djl-demo</artifactId>
    <version>1.0.0</version>
    
    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <os.detected.classifier>${os.detected.name}-${os.detected.arch}</os.detected.classifier>
    </properties>
    
    <dependencies>
        <!-- DJL API -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.25.0</version>
        </dependency>
        
        <!-- PyTorch 引擎 -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.25.0</version>
        </dependency>
        
        <!-- PyTorch 原生库(CPU) -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <version>2.0.1</version>
            <classifier>${os.detected.classifier}</classifier>
        </dependency>
        
        <!-- Model Zoo -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.25.0</version>
        </dependency>
        
        <!-- Hugging Face Tokenizers -->
        <dependency>
            <groupId>ai.djl.huggingface</groupId>
            <artifactId>tokenizers</artifactId>
            <version>0.25.0</version>
        </dependency>
    </dependencies>
    
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>11</source>
                    <target>11</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

Gradle 依赖配置

build.gradle 示例:

plugins {
    id 'java'
}

repositories {
    mavenCentral()
}

dependencies {
    implementation 'ai.djl:api:0.25.0'
    implementation 'ai.djl.pytorch:pytorch-engine:0.25.0'
    implementation 'ai.djl.pytorch:pytorch-native-cpu:2.0.1:linux-x86_64'
    implementation 'ai.djl:model-zoo:0.25.0'
    implementation 'ai.djl.huggingface:tokenizers:0.25.0'
}

java {
    sourceCompatibility = JavaVersion.VERSION_11
    targetCompatibility = JavaVersion.VERSION_11
}

验证安装

创建测试类验证安装:

import ai.djl.engine.Engine;
import java.util.List;

public class DJLInstallationTest {
    public static void main(String[] args) {
        System.out.println("=== DJL 安装验证 ===");
        
        // 1. 检查 DJL 版本
        System.out.println("DJL 版本: " + Engine.getDjlVersion());
        
        // 2. 获取所有可用引擎
        List<Engine> engines = Engine.getAllEngines();
        System.out.println("\n可用引擎数量: " + engines.size());
        
        for (Engine engine : engines) {
            System.out.println("\n引擎信息:");
            System.out.println("  名称: " + engine.getEngineName());
            System.out.println("  版本: " + engine.getVersion());
            System.out.println("  GPU 支持: " + engine.hasCapability("CUDA"));
        }
        
        // 3. 获取默认引擎
        Engine defaultEngine = Engine.getInstance();
        System.out.println("\n默认引擎: " + defaultEngine.getEngineName());
        
        // 4. 检查 GPU 可用性
        if (defaultEngine.hasCapability("CUDA")) {
            System.out.println("✓ GPU 可用");
        } else {
            System.out.println("✗ GPU 不可用,将使用 CPU");
        }
        
        System.out.println("\n=== 安装验证完成 ===");
    }
}

运行结果示例:

=== DJL 安装验证 ===
DJL 版本: 0.25.0

可用引擎数量: 1

引擎信息:
  名称: PyTorch
  版本: 2.0.1
  GPU 支持: false

默认引擎: PyTorch
✗ GPU 不可用,将使用 CPU

=== 安装验证完成 ===

类比理解:

  • 就像安装软件:添加依赖 → 验证安装 → 开始使用
  • 就像准备工具:拿到工具 → 检查工具 → 开始工作
  • 就像体检:检查各项指标 → 确认健康 → 开始工作

第三部分:DJL 的核心概念

1. Engine(引擎)

引擎是 DJL 对不同深度学习框架的抽象。每个引擎对应一个深度学习框架。

支持的引擎:

  • PyTorch:最常用,支持 Hugging Face 模型,推荐用于新项目
  • TensorFlow:Google 的框架,适合 TensorFlow 模型
  • MXNet:Apache 的框架,轻量级
  • ONNX Runtime:跨平台推理引擎,性能优秀
  • PaddlePaddle:百度的框架,适合中文场景

类比理解:

  • 就像驱动程序:不同的硬件需要不同的驱动,不同的框架需要不同的引擎
  • 就像翻译器:不同的语言需要不同的翻译器,不同的框架需要不同的引擎
  • 就像适配器:不同的接口需要不同的适配器

使用示例:

import ai.djl.engine.Engine;
import java.util.List;

public class EngineExample {
    public static void main(String[] args) {
        // 1. 获取默认引擎
        Engine engine = Engine.getInstance();
        System.out.println("默认引擎: " + engine.getEngineName());
        System.out.println("引擎版本: " + engine.getVersion());
        
        // 2. 获取所有可用引擎
        List<Engine> engines = Engine.getAllEngines();
        System.out.println("\n所有可用引擎:");
        for (Engine e : engines) {
            System.out.println("  - " + e.getEngineName() + " " + e.getVersion());
        }
        
        // 3. 指定使用特定引擎
        Engine pytorch = Engine.getEngine("PyTorch");
        if (pytorch != null) {
            System.out.println("\nPyTorch 引擎可用");
            System.out.println("  GPU 支持: " + pytorch.hasCapability("CUDA"));
        }
        
        // 4. 检查引擎能力
        Engine defaultEngine = Engine.getInstance();
        System.out.println("\n引擎能力:");
        System.out.println("  CUDA: " + defaultEngine.hasCapability("CUDA"));
        System.out.println("  MKL: " + defaultEngine.hasCapability("MKL"));
    }
}

2. Model(模型)

模型是 DJL 对机器学习模型的抽象。模型可以是预训练的,也可以是自己训练的。

模型类型:

  • ZooModel:从 ModelZoo 加载的预训练模型
  • Model:自定义训练的模型

类比理解:

  • 就像容器:装载训练好的模型
  • 就像包装盒:包装不同格式的模型
  • 就像程序:封装了算法逻辑

加载模型示例:

import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.Application;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;

public class ModelExample {
    public static void main(String[] args) throws Exception {
        // 1. 定义模型标准
        Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optApplication(Application.CV.IMAGE_CLASSIFICATION)
            .optFilter("dataset", "imagenet")
            .optFilter("layers", "50")
            .optEngine("PyTorch")
            .build();
        
        // 2. 加载模型
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
            System.out.println("模型加载成功");
            System.out.println("模型名称: " + model.getName());
            System.out.println("模型路径: " + model.getModelPath());
        }
    }
}

3. Translator(转换器)

转换器负责将输入数据转换为模型需要的格式,以及将模型输出转换为 Java 对象。

转换器的作用:

  • 输入转换:将 Java 对象(如 Image、String)转换为 NDArray
  • 输出转换:将 NDArray 转换为 Java 对象(如 Classifications、String)

类比理解:

  • 就像翻译器:将 Java 数据翻译成模型格式
  • 就像适配器:适配不同格式的数据
  • 就像转换器:将一种格式转换为另一种格式

自定义转换器示例:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;

public class ImageClassificationTranslator implements Translator<Image, Classifications> {
    
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        
        // 1. 调整图像大小
        Image resized = input.resize(224, 224);
        
        // 2. 转换为 NDArray
        NDArray array = resized.toNDArray(manager, Image.Flag.COLOR);
        
        // 3. 归一化(ImageNet 标准)
        array = array.div(255.0f);
        array = array.sub(new float[]{0.485f, 0.456f, 0.406f});
        array = array.div(new float[]{0.229f, 0.224f, 0.225f});
        
        // 4. 添加批次维度 [1, 3, 224, 224]
        array = array.expandDims(0);
        
        return new NDList(array);
    }
    
    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        NDArray output = list.singletonOrThrow();
        
        // 1. 应用 softmax
        output = output.softmax(1);
        
        // 2. 转换为 Classifications
        return new Classifications(output);
    }
    
    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

4. Predictor(预测器)

预测器用于运行模型进行预测。它是模型和转换器的组合。

类比理解:

  • 就像执行器:执行模型推理
  • 就像计算器:输入数据,输出结果
  • 就像预测器:根据输入预测输出

使用示例:

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;

public class PredictorExample {
    public static void main(String[] args) throws Exception {
        // 假设已经加载了模型
        ZooModel<Image, Classifications> model = loadModel();
        
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // 1. 加载图像
            Image image = ImageFactory.getInstance()
                .fromUrl("https://example.com/cat.jpg");
            
            // 2. 进行预测
            Classifications classifications = predictor.predict(image);
            
            // 3. 显示结果
            System.out.println("预测结果:");
            classifications.topK(5).forEach(item -> 
                System.out.println(String.format(
                    "  %s: %.2f%%", 
                    item.getClassName(), 
                    item.getProbability() * 100
                ))
            );
        }
    }
    
    private static ZooModel<Image, Classifications> loadModel() {
        // 加载模型的代码
        return null;
    }
}

5. NDArray(多维数组)

NDArray是 DJL 的核心数据结构,类似于 NumPy 的 ndarray。

NDArray 的特点:

  • 多维数组:支持任意维度的数组
  • 类型支持:支持 float、double、int、long 等类型
  • 操作丰富:支持各种数学运算、形状变换等
  • 内存管理:自动管理内存,避免内存泄漏

类比理解:

  • 就像NumPy 数组:Python 中的多维数组
  • 就像矩阵:数学中的矩阵运算
  • 就像张量:深度学习中的基本数据结构

NDArray 使用示例:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class NDArrayExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // 1. 创建数组
            NDArray array = manager.create(new float[]{1, 2, 3, 4, 5, 6});
            System.out.println("原始数组: " + array);
            
            // 2. 改变形状
            NDArray reshaped = array.reshape(2, 3);
            System.out.println("重塑后: " + reshaped);
            
            // 3. 数学运算
            NDArray doubled = array.mul(2);
            System.out.println("乘以2: " + doubled);
            
            // 4. 创建零数组
            NDArray zeros = manager.zeros(new Shape(3, 4));
            System.out.println("零数组:\n" + zeros);
            
            // 5. 创建随机数组
            NDArray random = manager.randomNormal(new Shape(2, 3));
            System.out.println("随机数组:\n" + random);
        }
    }
}

第四部分:DJL 的基础使用

示例1:图像分类(完整示例)

使用预训练模型进行图像分类:

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.nio.file.Paths;

public class ImageClassificationExample {
    public static void main(String[] args) {
        try {
            // 1. 定义模型标准
            Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .optFilter("dataset", "imagenet")
                .optFilter("layers", "50")
                .optEngine("PyTorch")
                .build();
            
            // 2. 加载模型
            try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                 Predictor<Image, Classifications> predictor = model.newPredictor()) {
                
                System.out.println("模型加载成功: " + model.getName());
                
                // 3. 加载图像(可以从文件、URL 或字节数组加载)
                Image image = ImageFactory.getInstance()
                    .fromFile(Paths.get("path/to/image.jpg"));
                
                // 或者从 URL 加载
                // Image image = ImageFactory.getInstance()
                //     .fromUrl("https://example.com/image.jpg");
                
                // 4. 进行预测
                long startTime = System.currentTimeMillis();
                Classifications classifications = predictor.predict(image);
                long endTime = System.currentTimeMillis();
                
                // 5. 显示结果
                System.out.println("\n预测结果(Top 5):");
                System.out.println("预测耗时: " + (endTime - startTime) + "ms");
                System.out.println("----------------------------------------");
                
                classifications.topK(5).forEach(item -> 
                    System.out.println(String.format(
                        "  %-30s: %.2f%%", 
                        item.getClassName(), 
                        item.getProbability() * 100
                    ))
                );
            }
        } catch (ModelException e) {
            System.err.println("模型加载失败: " + e.getMessage());
            e.printStackTrace();
        } catch (TranslateException e) {
            System.err.println("预测失败: " + e.getMessage());
            e.printStackTrace();
        } catch (Exception e) {
            System.err.println("发生错误: " + e.getMessage());
            e.printStackTrace();
        }
    }
}

示例2:文本分类(Hugging Face)

使用 Hugging Face 模型进行文本分类:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Arrays;

public class TextClassificationExample {
    
    // 自定义转换器
    static class TextClassificationTranslator implements Translator<String, Classifications> {
        private HuggingFaceTokenizer tokenizer;
        
        public TextClassificationTranslator(String tokenizerPath) throws Exception {
            tokenizer = HuggingFaceTokenizer.builder()
                .optTokenizerPath(Paths.get(tokenizerPath))
                .optPadding(true)
                .optMaxLength(128)
                .build();
        }
        
        @Override
        public NDList processInput(TranslatorContext ctx, String input) {
            NDManager manager = ctx.getNDManager();
            
            // 1. 分词
            Encoding encoding = tokenizer.encode(input);
            
            // 2. 转换为 NDArray
            long[] inputIds = encoding.getIds();
            long[] attentionMask = encoding.getAttentionMask();
            
            NDArray inputIdsArray = manager.create(inputIds).expandDims(0);
            NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
            
            return new NDList(inputIdsArray, attentionMaskArray);
        }
        
        @Override
        public Classifications processOutput(TranslatorContext ctx, NDList list) {
            NDArray output = list.singletonOrThrow();
            
            // 应用 softmax
            output = output.softmax(1);
            
            // 转换为 Classifications
            return new Classifications(output, 
                Arrays.asList("negative", "positive"));
        }
    }
    
    public static void main(String[] args) throws Exception {
        // 1. 定义模型标准
        Criteria<String, Classifications> criteria = Criteria.builder()
            .setTypes(String.class, Classifications.class)
            .optModelPath(Paths.get("models/distilbert-sentiment"))
            .optModelName("distilbert-sentiment")
            .optEngine("PyTorch")
            .optTranslator(new TextClassificationTranslator("tokenizer.json"))
            .build();
        
        // 2. 加载模型
        try (ZooModel<String, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<String, Classifications> predictor = model.newPredictor()) {
            
            // 3. 测试文本
            String[] texts = {
                "I love Java programming!",
                "This is terrible.",
                "The weather is nice today.",
                "I hate this product."
            };
            
            // 4. 进行预测
            for (String text : texts) {
                Classifications result = predictor.predict(text);
                System.out.println("文本: " + text);
                System.out.println("情感: " + result.best().getClassName());
                System.out.println("置信度: " + 
                    String.format("%.2f%%", result.best().getProbability() * 100));
                System.out.println("---");
            }
        }
    }
}

示例3:目标检测

使用预训练模型进行目标检测:

import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

public class ObjectDetectionExample {
    public static void main(String[] args) throws Exception {
        // 1. 定义模型标准
        Criteria<Image, DetectedObjects> criteria = Criteria.builder()
            .setTypes(Image.class, DetectedObjects.class)
            .optApplication(Application.CV.OBJECT_DETECTION)
            .optFilter("backbone", "resnet50")
            .optEngine("PyTorch")
            .build();
        
        // 2. 加载模型
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
             Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            
            // 3. 加载图像
            Image image = ImageFactory.getInstance()
                .fromFile(Paths.get("path/to/image.jpg"));
            
            // 4. 进行检测
            DetectedObjects detections = predictor.predict(image);
            
            // 5. 显示结果
            System.out.println("检测到 " + detections.getNumberOfObjects() + " 个对象:");
            System.out.println("----------------------------------------");
            
            detections.items().forEach(item -> {
                System.out.println(String.format(
                    "类别: %s (%.2f%%)",
                    item.getClassName(),
                    item.getProbability() * 100
                ));
                System.out.println(String.format(
                    "位置: [%.0f, %.0f, %.0f, %.0f]",
                    item.getBoundingBox().getX(),
                    item.getBoundingBox().getY(),
                    item.getBoundingBox().getWidth(),
                    item.getBoundingBox().getHeight()
                ));
                System.out.println("---");
            });
        }
    }
}

示例4:批量处理

批量处理多个图像:

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class BatchProcessingExample {
    public static void main(String[] args) throws Exception {
        // 假设已经加载了模型
        ZooModel<Image, Classifications> model = loadModel();
        
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // 1. 准备批量图像路径
            List<Path> imagePaths = Arrays.asList(
                Paths.get("image1.jpg"),
                Paths.get("image2.jpg"),
                Paths.get("image3.jpg")
            );
            
            // 2. 加载所有图像
            List<Image> images = imagePaths.stream()
                .map(path -> {
                    try {
                        return ImageFactory.getInstance().fromFile(path);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                })
                .collect(Collectors.toList());
            
            // 3. 批量预测
            long startTime = System.currentTimeMillis();
            List<Classifications> results = images.stream()
                .map(predictor::predict)
                .collect(Collectors.toList());
            long endTime = System.currentTimeMillis();
            
            // 4. 处理结果
            System.out.println("批量处理完成");
            System.out.println("处理数量: " + images.size());
            System.out.println("总耗时: " + (endTime - startTime) + "ms");
            System.out.println("平均耗时: " + (endTime - startTime) / images.size() + "ms/张");
            System.out.println("----------------------------------------");
            
            for (int i = 0; i < results.size(); i++) {
                Classifications classification = results.get(i);
                System.out.println("图像 " + (i + 1) + ":");
                System.out.println("  预测: " + classification.best().getClassName());
                System.out.println("  置信度: " + 
                    String.format("%.2f%%", classification.best().getProbability() * 100));
                System.out.println("---");
            }
        }
    }
    
    private static ZooModel<Image, Classifications> loadModel() {
        // 加载模型的代码
        return null;
    }
}

第五部分:加载 Hugging Face 模型

方法1:直接加载 Hugging Face 模型

从本地路径加载 Hugging Face 模型:

import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.Application;
import ai.djl.modality.nlp.qa.QAInput;
import java.nio.file.Paths;

public class HuggingFaceModelExample {
    public static void main(String[] args) throws Exception {
        // 1. 定义模型标准
        Criteria<QAInput, String> criteria = Criteria.builder()
            .setTypes(QAInput.class, String.class)
            .optModelPath(Paths.get("models/bert-base-uncased"))
            .optModelName("bert-base-uncased")
            .optEngine("PyTorch")
            .optApplication(Application.NLP.QUESTION_ANSWER)
            .build();
        
        // 2. 加载模型
        try (ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria);
             Predictor<QAInput, String> predictor = model.newPredictor()) {
            
            // 3. 准备输入
            QAInput input = new QAInput(
                "What is the capital of France?",
                "Paris is the capital of France. It is a beautiful city."
            );
            
            // 4. 进行预测
            String answer = predictor.predict(input);
            System.out.println("问题: " + input.getQuestion());
            System.out.println("上下文: " + input.getParagraph());
            System.out.println("答案: " + answer);
        }
    }
}

方法2:从 Hugging Face Hub 下载模型

自动从 Hugging Face Hub 下载模型:

import ai.djl.huggingface.translator.TextClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

public class HuggingFaceHubExample {
    public static void main(String[] args) throws Exception {
        // 1. 定义模型标准(会自动从 Hugging Face Hub 下载)
        Criteria<String, Classifications> criteria = Criteria.builder()
            .setTypes(String.class, Classifications.class)
            .optModelName("distilbert-base-uncased-finetuned-sst-2-english")
            .optModelPath(Paths.get("./models"))  // 下载到本地目录
            .optEngine("PyTorch")
            .optTranslatorFactory(new TextClassificationTranslatorFactory())
            .build();
        
        // 2. 加载模型(首次会下载,之后会使用缓存)
        ZooModel<String, Classifications> model = ModelZoo.loadModel(criteria);
        
        // 使用模型...
        model.close();
    }
}

方法3:使用 Hugging Face Tokenizers

使用 Hugging Face Tokenizers 进行文本处理:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import java.nio.file.Paths;

public class HuggingFaceTokenizerExample {
    public static void main(String[] args) throws Exception {
        // 1. 加载分词器
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder()
            .optTokenizerPath(Paths.get("tokenizer.json"))
            .optPadding(true)
            .optMaxLength(128)
            .optTruncation(true)
            .build();
        
        // 2. 分词
        String text = "Hello, how are you?";
        Encoding encoding = tokenizer.encode(text);
        
        // 3. 获取结果
        System.out.println("原始文本: " + text);
        System.out.println("Token IDs: " + Arrays.toString(encoding.getIds()));
        System.out.println("Attention Mask: " + Arrays.toString(encoding.getAttentionMask()));
        System.out.println("Token 数量: " + encoding.getIds().length);
        
        // 4. 批量编码
        List<String> texts = Arrays.asList(
            "Hello, world!",
            "How are you?",
            "I love Java!"
        );
        
        List<Encoding> encodings = tokenizer.batchEncode(texts);
        for (int i = 0; i < encodings.size(); i++) {
            System.out.println("\n文本 " + (i + 1) + ": " + texts.get(i));
            System.out.println("Token IDs: " + Arrays.toString(encodings.get(i).getIds()));
        }
    }
}

第六部分:自定义模型和转换器

自定义转换器详解

创建完整的图像预处理转换器:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;

public class CustomImageTranslator implements Translator<Image, Classifications> {
    
    private int imageSize;
    private float[] mean;
    private float[] std;
    
    public CustomImageTranslator() {
        this(224, 
             new float[]{0.485f, 0.456f, 0.406f}, 
             new float[]{0.229f, 0.224f, 0.225f});
    }
    
    public CustomImageTranslator(int imageSize, float[] mean, float[] std) {
        this.imageSize = imageSize;
        this.mean = mean;
        this.std = std;
    }
    
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        
        // 1. 调整大小(保持宽高比)
        Image resized = input.resize(imageSize, imageSize);
        
        // 2. 转换为 NDArray [H, W, C]
        NDArray array = resized.toNDArray(manager, Image.Flag.COLOR);
        
        // 3. 转换为 [C, H, W] 格式
        array = array.transpose(2, 0, 1);
        
        // 4. 归一化到 [0, 1]
        array = array.div(255.0f);
        
        // 5. 标准化(减去均值,除以标准差)
        NDArray meanArray = manager.create(mean).reshape(3, 1, 1);
        NDArray stdArray = manager.create(std).reshape(3, 1, 1);
        array = array.sub(meanArray).div(stdArray);
        
        // 6. 添加批次维度 [1, C, H, W]
        array = array.expandDims(0);
        
        return new NDList(array);
    }
    
    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        NDArray output = list.singletonOrThrow();
        
        // 1. 移除批次维度(如果有)
        if (output.getShape().dimension() > 1) {
            output = output.squeeze(0);
        }
        
        // 2. 应用 softmax
        output = output.softmax(0);
        
        // 3. 转换为 Classifications
        return new Classifications(output);
    }
    
    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

自定义模型训练

创建简单的神经网络模型:

import ai.djl.Model;
import ai.djl.nn.*;
import ai.djl.nn.core.Linear;
import ai.djl.nn.Activation;
import ai.djl.nn.SequentialBlock;
import ai.djl.ndarray.types.Shape;
import java.nio.file.Paths;

public class CustomModelExample {
    public static void main(String[] args) throws Exception {
        // 1. 创建模型架构
        SequentialBlock block = new SequentialBlock()
            .add(Linear.builder().setUnits(128).build())
            .add(Activation::relu)
            .add(Linear.builder().setUnits(64).build())
            .add(Activation::relu)
            .add(Linear.builder().setUnits(10).build())
            .add(Activation::softmax);
        
        // 2. 创建模型实例
        Model model = Model.newInstance("custom_model");
        model.setBlock(block);
        
        // 3. 初始化参数
        model.getBlock().initialize(
            model.getNDManager(), 
            new Shape(1, 784)  // 输入形状:[batch_size, features]
        );
        
        // 4. 训练模型(示例代码,实际需要数据集和训练循环)
        // trainModel(model);
        
        // 5. 保存模型
        model.save(Paths.get("models"), "custom_model");
        
        System.out.println("模型保存成功");
        
        // 6. 关闭模型
        model.close();
    }
}

第七部分:实际应用案例

案例1:图像背景移除(BEN2 模型)

使用 BEN2 模型移除图像背景:

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import java.nio.file.Paths;

public class BEN2BackgroundRemoval {
    
    // 自定义转换器
    static class BEN2Translator implements Translator<Image, Image> {
        @Override
        public NDList processInput(TranslatorContext ctx, Image input) {
            NDManager manager = ctx.getNDManager();
            
            // 1. 调整大小到 1024x1024
            Image resized = input.resize(1024, 1024);
            
            // 2. 转换为 NDArray [H, W, C]
            NDArray array = resized.toNDArray(manager, Image.Flag.COLOR);
            
            // 3. 归一化到 [0, 1]
            array = array.div(255.0f);
            
            // 4. 转换为 [C, H, W] 格式并添加批次维度
            array = array.transpose(2, 0, 1).expandDims(0);
            
            return new NDList(array);
        }
        
        @Override
        public Image processOutput(TranslatorContext ctx, NDList list) {
            NDArray output = list.singletonOrThrow();
            
            // 1. 移除批次维度并转换回 [H, W, C]
            output = output.squeeze(0).transpose(1, 2, 0);
            
            // 2. 反归一化
            output = output.mul(255.0f).clip(0, 255);
            
            // 3. 转换为 Image
            return ImageFactory.getInstance().fromNDArray(output);
        }
    }
    
    public static void main(String[] args) throws Exception {
        // 1. 定义模型标准
        Criteria<Image, Image> criteria = Criteria.builder()
            .setTypes(Image.class, Image.class)
            .optModelPath(Paths.get("models/ben2"))
            .optModelName("ben2")
            .optEngine("PyTorch")
            .optTranslator(new BEN2Translator())
            .build();
        
        // 2. 加载模型
        try (ZooModel<Image, Image> model = ModelZoo.loadModel(criteria);
             Predictor<Image, Image> predictor = model.newPredictor()) {
            
            // 3. 加载输入图像
            Image inputImage = ImageFactory.getInstance()
                .fromFile(Paths.get("input.jpg"));
            
            // 4. 移除背景
            Image foreground = predictor.predict(inputImage);
            
            // 5. 保存结果
            foreground.save(Paths.get("foreground.png"), "png");
            
            System.out.println("背景移除完成!");
        }
    }
}

案例2:文本情感分析(完整实现)

完整的文本情感分析系统:

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

public class SentimentAnalysisSystem {
    
    static class SentimentTranslator implements Translator<String, Classifications> {
        private HuggingFaceTokenizer tokenizer;
        
        public SentimentTranslator(String tokenizerPath) throws Exception {
            tokenizer = HuggingFaceTokenizer.builder()
                .optTokenizerPath(Paths.get(tokenizerPath))
                .optPadding(true)
                .optMaxLength(128)
                .optTruncation(true)
                .build();
        }
        
        @Override
        public NDList processInput(TranslatorContext ctx, String input) {
            NDManager manager = ctx.getNDManager();
            
            // 分词
            Encoding encoding = tokenizer.encode(input);
            
            // 转换为 NDArray
            long[] inputIds = encoding.getIds();
            long[] attentionMask = encoding.getAttentionMask();
            
            NDArray inputIdsArray = manager.create(inputIds).expandDims(0);
            NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
            
            return new NDList(inputIdsArray, attentionMaskArray);
        }
        
        @Override
        public Classifications processOutput(TranslatorContext ctx, NDList list) {
            NDArray output = list.singletonOrThrow();
            
            // 应用 softmax
            output = output.softmax(1);
            
            // 转换为 Classifications
            return new Classifications(output, 
                Arrays.asList("negative", "positive"));
        }
    }
    
    public static void main(String[] args) throws Exception {
        // 1. 初始化系统
        System.out.println("初始化情感分析系统...");
        
        Criteria<String, Classifications> criteria = Criteria.builder()
            .setTypes(String.class, Classifications.class)
            .optModelPath(Paths.get("models/sentiment_model"))
            .optModelName("distilbert-sentiment")
            .optEngine("PyTorch")
            .optTranslator(new SentimentTranslator("tokenizer.json"))
            .build();
        
        try (ZooModel<String, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<String, Classifications> predictor = model.newPredictor()) {
            
            System.out.println("系统初始化完成!\n");
            
            // 2. 测试文本
            List<String> testTexts = Arrays.asList(
                "I love Java programming!",
                "This is terrible.",
                "The weather is nice today.",
                "I hate this product.",
                "This is amazing!"
            );
            
            // 3. 批量分析
            System.out.println("开始情感分析...\n");
            for (String text : testTexts) {
                Classifications result = predictor.predict(text);
                
                String sentiment = result.best().getClassName();
                double confidence = result.best().getProbability();
                
                // 格式化输出
                System.out.println("文本: " + text);
                System.out.println("情感: " + sentiment.toUpperCase());
                System.out.println("置信度: " + String.format("%.2f%%", confidence * 100));
                
                // 显示所有类别
                System.out.println("详细结果:");
                result.items().forEach(item -> 
                    System.out.println("  " + item.getClassName() + ": " + 
                        String.format("%.2f%%", item.getProbability() * 100))
                );
                System.out.println("---");
            }
        }
    }
}

案例3:Spring Boot 集成

在 Spring Boot 应用中集成 DJL:

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

@Service
public class ImageClassificationService {
    
    @Autowired
    private ModelLoader modelLoader;  // 自定义的模型加载器
    
    private ZooModel<Image, Classifications> model;
    private Predictor<Image, Classifications> predictor;
    
    @PostConstruct
    public void init() throws Exception {
        // 初始化时加载模型
        model = modelLoader.loadModel();
        predictor = model.newPredictor();
    }
    
    @PreDestroy
    public void destroy() {
        // 关闭时释放资源
        if (predictor != null) {
            predictor.close();
        }
        if (model != null) {
            model.close();
        }
    }
    
    public ClassificationResult classifyImage(MultipartFile file) throws Exception {
        // 1. 保存临时文件
        Path tempFile = Files.createTempFile("upload", ".jpg");
        file.transferTo(tempFile.toFile());
        
        try {
            // 2. 加载图像
            Image image = ImageFactory.getInstance().fromFile(tempFile);
            
            // 3. 进行预测
            Classifications classifications = predictor.predict(image);
            
            // 4. 构建结果
            ClassificationResult result = new ClassificationResult();
            result.setTopPrediction(
                classifications.best().getClassName(),
                classifications.best().getProbability()
            );
            result.setTopK(classifications.topK(5));
            
            return result;
        } finally {
            // 5. 删除临时文件
            Files.deleteIfExists(tempFile);
        }
    }
}

// REST Controller
@RestController
@RequestMapping("/api/classify")
public class ClassificationController {
    
    @Autowired
    private ImageClassificationService service;
    
    @PostMapping("/image")
    public ResponseEntity<ClassificationResult> classifyImage(
            @RequestParam("file") MultipartFile file) {
        try {
            ClassificationResult result = service.classifyImage(file);
            return ResponseEntity.ok(result);
        } catch (Exception e) {
            return ResponseEntity.status(500).build();
        }
    }
}

第八部分:DJL 的高级特性

1. GPU 支持与设备管理

使用 GPU 加速推理:

import ai.djl.engine.Engine;
import ai.djl.Device;
import ai.djl.repository.zoo.Criteria;

public class GPUSupportExample {
    public static void main(String[] args) {
        // 1. 检查 GPU 可用性
        Engine engine = Engine.getInstance();
        boolean hasGPU = engine.hasCapability("CUDA");
        
        System.out.println("GPU 可用: " + hasGPU);
        
        // 2. 获取 GPU 设备
        if (hasGPU) {
            // 使用第一个 GPU
            Device gpu0 = Device.gpu(0);
            
            // 使用第二个 GPU(如果有)
            // Device gpu1 = Device.gpu(1);
            
            // 3. 在模型标准中指定设备
            Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optDevice(gpu0)  // 指定使用 GPU
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .build();
            
            System.out.println("将使用 GPU 进行推理");
        } else {
            // 使用 CPU
            Device cpu = Device.cpu();
            
            Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optDevice(cpu)  // 指定使用 CPU
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .build();
            
            System.out.println("将使用 CPU 进行推理");
        }
    }
}

2. 模型管理与缓存

管理模型缓存和版本:

import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Path;
import java.nio.file.Paths;

public class ModelManagementExample {
    public static void main(String[] args) throws Exception {
        // 1. 指定模型缓存目录
        Path modelDir = Paths.get("./models_cache");
        
        Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optApplication(Application.CV.IMAGE_CLASSIFICATION)
            .optModelPath(modelDir)  // 指定缓存目录
            .build();
        
        // 2. 加载模型(首次会下载,之后使用缓存)
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
            System.out.println("模型路径: " + model.getModelPath());
            System.out.println("模型名称: " + model.getName());
            
            // 3. 保存模型到指定位置
            Path savePath = Paths.get("./saved_models/my_model");
            model.save(savePath, "my_model");
            
            System.out.println("模型已保存到: " + savePath);
        }
    }
}

3. 性能优化技巧

模型预热和批量处理优化:

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

public class PerformanceOptimizationExample {
    
    public static void warmupModel(ZooModel<Image, Classifications> model) throws Exception {
        System.out.println("预热模型...");
        
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // 创建虚拟图像进行预热
            Image dummyImage = ImageFactory.getInstance()
                .fromFile(Paths.get("dummy.jpg"));
            
            // 运行多次预热
            for (int i = 0; i < 10; i++) {
                predictor.predict(dummyImage);
            }
        }
        
        System.out.println("模型预热完成");
    }
    
    public static void benchmarkModel(ZooModel<Image, Classifications> model) throws Exception {
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image testImage = ImageFactory.getInstance()
                .fromFile(Paths.get("test.jpg"));
            
            // 单次推理基准测试
            int warmupRuns = 5;
            int testRuns = 100;
            
            // 预热
            for (int i = 0; i < warmupRuns; i++) {
                predictor.predict(testImage);
            }
            
            // 测试
            long totalTime = 0;
            for (int i = 0; i < testRuns; i++) {
                long startTime = System.nanoTime();
                predictor.predict(testImage);
                long endTime = System.nanoTime();
                totalTime += (endTime - startTime);
            }
            
            double avgTime = totalTime / (double) testRuns / 1_000_000; // 转换为毫秒
            System.out.println("平均推理时间: " + String.format("%.2f", avgTime) + "ms");
            System.out.println("吞吐量: " + String.format("%.2f", 1000 / avgTime) + " 次/秒");
        }
    }
}

4. 多线程推理

使用线程池进行并发推理:

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import java.util.List;
import java.util.concurrent.*;
import java.util.stream.Collectors;

public class MultiThreadInferenceExample {
    
    public static List<Classifications> predictBatch(
            ZooModel<Image, Classifications> model,
            List<Image> images,
            int threadPoolSize) throws Exception {
        
        // 1. 创建线程池
        ExecutorService executor = Executors.newFixedThreadPool(threadPoolSize);
        
        try {
            // 2. 为每个线程创建独立的 Predictor
            List<Future<Classifications>> futures = images.stream()
                .map(image -> executor.submit(() -> {
                    try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
                        return predictor.predict(image);
                    }
                }))
                .collect(Collectors.toList());
            
            // 3. 收集结果
            return futures.stream()
                .map(future -> {
                    try {
                        return future.get();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                })
                .collect(Collectors.toList());
        } finally {
            executor.shutdown();
        }
    }
}

第九部分:最佳实践

1. 资源管理

使用 try-with-resources 确保资源正确释放:

// ✅ 正确:自动关闭资源
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
     Predictor<Image, Classifications> predictor = model.newPredictor()) {
    // 使用模型
    Classifications result = predictor.predict(image);
}

// ❌ 错误:手动管理资源(容易忘记关闭)
ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<Image, Classifications> predictor = model.newPredictor();
// 使用后需要手动关闭
predictor.close();
model.close();

2. 错误处理

完善的错误处理机制:

import ai.djl.ModelException;
import ai.djl.translate.TranslateException;

public class ErrorHandlingExample {
    public static void classifyImage(String imagePath) {
        try {
            ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
            // 使用模型...
        } catch (ModelException e) {
            System.err.println("模型加载失败: " + e.getMessage());
            // 记录日志
            logger.error("模型加载失败", e);
            // 使用备用模型或返回默认结果
        } catch (TranslateException e) {
            System.err.println("推理失败: " + e.getMessage());
            logger.error("推理失败", e);
        } catch (Exception e) {
            System.err.println("发生未知错误: " + e.getMessage());
            logger.error("未知错误", e);
        }
    }
}

3. 日志记录

配置 DJL 日志:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LoggingExample {
    private static final Logger logger = LoggerFactory.getLogger(LoggingExample.class);
    
    public static void main(String[] args) {
        // 设置 DJL 日志级别
        System.setProperty("ai.djl.logging.level", "INFO");
        
        logger.info("开始加载模型...");
        try {
            ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
            logger.info("模型加载成功");
        } catch (Exception e) {
            logger.error("模型加载失败", e);
        }
    }
}

4. 配置管理

使用配置文件管理模型参数:

import java.util.Properties;
import java.io.InputStream;

public class ConfigExample {
    public static Properties loadConfig() throws Exception {
        Properties props = new Properties();
        try (InputStream is = ConfigExample.class
                .getResourceAsStream("/djl.properties")) {
            props.load(is);
        }
        return props;
    }
    
    public static Criteria<Image, Classifications> buildCriteriaFromConfig() 
            throws Exception {
        Properties config = loadConfig();
        
        return Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optApplication(Application.CV.IMAGE_CLASSIFICATION)
            .optEngine(config.getProperty("engine", "PyTorch"))
            .optModelPath(Paths.get(config.getProperty("model.path", "./models")))
            .build();
    }
}

第十部分:常见问题与解决方案

问题1:找不到引擎

错误信息:

No engine found for: PyTorch

解决方案:

  1. 检查依赖配置:
<!-- 确保添加了引擎依赖 -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.25.0</version>
</dependency>

<!-- 以及原生库 -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <version>2.0.1</version>
    <classifier>${os}-x86_64</classifier>
</dependency>
  1. 检查操作系统分类器:
<properties>
    <os.detected.classifier>${os.detected.name}-${os.detected.arch}</os.detected.classifier>
</properties>

问题2:模型加载失败

解决方案:

// 1. 检查模型路径
Path modelPath = Paths.get("models/my_model");
if (!Files.exists(modelPath)) {
    System.err.println("模型路径不存在: " + modelPath);
    return;
}

// 2. 检查模型格式
// 确保模型格式与引擎匹配(PyTorch 模型使用 PyTorch 引擎)

// 3. 使用正确的引擎
Criteria<Image, Classifications> criteria = Criteria.builder()
    .setTypes(Image.class, Classifications.class)
    .optModelPath(modelPath)
    .optEngine("PyTorch")  // 明确指定引擎
    .build();

问题3:内存不足

解决方案:

// 1. 减少批次大小
// 2. 使用 CPU 而不是 GPU(如果 GPU 内存不足)
Device device = Device.cpu();

// 3. 及时释放资源
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
    // 使用后自动关闭
}

// 4. 增加 JVM 内存
// java -Xmx4g YourApplication

// 5. 使用模型量化(如果支持)

问题4:GPU 不可用

解决方案:

// 检查 GPU 可用性
Engine engine = Engine.getInstance();
if (engine.hasCapability("CUDA")) {
    Device device = Device.gpu(0);
    // 使用 GPU
} else {
    Device device = Device.cpu();
    // 使用 CPU
    System.out.println("GPU 不可用,使用 CPU");
}

问题5:性能问题

优化建议:

  1. 模型预热: 在正式使用前运行几次推理
  2. 批量处理: 使用批量处理提高吞吐量
  3. 使用 GPU: 如果可用,使用 GPU 加速
  4. 模型量化: 使用量化模型减少内存和计算量
  5. 多线程: 使用多线程进行并发推理

第十一部分:DJL vs 其他方案

对比 ONNX Runtime

特性DJLONNX Runtime
框架支持✅ 多种框架⚠️ 仅 ONNX
统一 API✅ 是⚠️ 仅 ONNX
Hugging Face✅ 直接支持❌ 需转换
模型训练✅ 支持❌ 仅推理
易用性⭐⭐⭐⭐⭐⭐⭐⭐⭐
性能⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐

选择建议:

  • 如果需要多种框架支持 → 选择 DJL
  • 如果只需要 ONNX 模型推理 → 选择 ONNX Runtime

对比 Deeplearning4j

特性DJLDeeplearning4j
框架支持✅ 多种框架❌ 仅自己的实现
Hugging Face✅ 直接支持❌ 不支持
API 设计⭐⭐⭐⭐⭐⭐⭐⭐
易用性⭐⭐⭐⭐⭐⭐⭐⭐
社区支持⭐⭐⭐⭐ AWS 支持⭐⭐⭐

选择建议:

  • 如果需要使用预训练模型(特别是 Hugging Face) → 选择 DJL
  • 如果需要完全 Java 实现的框架 → 选择 Deeplearning4j

对比 Smile

特性DJLSmile
深度学习✅ 支持❌ 不支持
传统 ML⚠️ 有限✅ 支持
Hugging Face✅ 支持❌ 不支持
易用性⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐

选择建议:

  • 如果需要深度学习 → 选择 DJL
  • 如果只需要传统机器学习 → 选择 Smile

第十二部分:总结

DJL 的核心优势

  1. 框架无关:一套代码支持所有框架,灵活切换
  2. 统一 API:学习成本低,一套 API 掌握所有框架
  3. Hugging Face 集成:可以直接加载模型,无需转换
  4. 易于使用:Java 风格,易于集成到现有项目
  5. 生产就绪:企业级支持,稳定可靠
  6. 高性能:直接调用底层框架,性能优秀
  7. 跨平台:支持多种操作系统和硬件

适用场景

适合使用 DJL:

  • ✅ Java 项目需要深度学习功能
  • ✅ 需要加载 Hugging Face 模型
  • ✅ 需要支持多种框架
  • ✅ 需要统一的 API
  • ✅ 企业级应用
  • ✅ 微服务架构中的模型服务

不适合使用 DJL:

  • ❌ 只需要传统机器学习(使用 Smile)
  • ❌ 只需要 ONNX 模型推理(使用 ONNX Runtime)
  • ❌ Python 项目(使用原生框架更合适)

学习路径建议

  1. 基础阶段:

    • 了解 DJL 的核心概念(Engine、Model、Translator、Predictor)
    • 完成简单的图像分类示例
    • 理解资源管理(try-with-resources)
  2. 进阶阶段:

    • 学习自定义转换器
    • 掌握 Hugging Face 模型加载
    • 了解 GPU 支持和性能优化
  3. 高级阶段:

    • 自定义模型训练
    • 生产环境部署
    • 性能调优和故障排查

类比总结

理解 DJL,就像理解为什么使用统一接口:

  • 框架无关:就像 USB-C 接口,可以连接任何设备
  • 统一 API:就像标准插座,任何电器都能用
  • 易于使用:就像通用遥控器,一个控制所有
  • Hugging Face 集成:就像应用商店,直接下载使用
  • 生产就绪:就像企业级软件,稳定可靠

掌握 DJL,你就能在 Java 生态系统中轻松使用深度学习!


参考资料


评论