DJL 与引擎无关的深度学习 Java 框架
引言
想象一下,你要在 Java 项目中集成深度学习模型:
- 方法1:为每个框架(PyTorch、TensorFlow、MXNet)分别写代码,维护多套代码(工作量大、维护困难)
- 方法2:使用统一的 API,一套代码支持所有框架,轻松切换和部署
DJL(Deep Java Library) 就像方法2——它是 AWS 开发的统一深度学习接口,让你可以在 Java 中使用 PyTorch、TensorFlow、MXNet、ONNX Runtime、PaddlePaddle 等所有主流深度学习框架,无需关心底层实现细节。
本文将用生动的类比、详细的代码示例和实际应用场景,带你从零开始,深入掌握 DJL 的强大功能,让你能够在 Java 生态系统中轻松进行深度学习开发。
第一部分:什么是 DJL?
DJL 的直观理解
DJL(Deep Java Library) 是 AWS 开发的开源深度学习库,为 Java 开发者提供了统一的 API 来使用各种深度学习框架。
类比理解:
- 就像统一接口:一套 API 支持所有深度学习框架,就像 USB-C 接口可以连接各种设备
- 就像适配器模式:将不同框架的差异隐藏起来,就像电源适配器适配不同电压
- 就像翻译器:将 Java 代码翻译成不同框架的调用,就像同声传译员
- 就像万能钥匙:一把钥匙打开所有框架的门,就像万能钥匙可以开所有锁
DJL 的核心特点
- 框架无关:支持 PyTorch、TensorFlow、MXNet、ONNX Runtime、PaddlePaddle
- 统一 API:一套代码,支持所有框架,无需修改代码即可切换框架
- 易于使用:Java 风格的 API,易于学习和使用,符合 Java 开发习惯
- 高性能:直接调用底层框架,性能优秀,接近原生框架性能
- Hugging Face 集成:可以直接加载 Hugging Face 模型,无需转换
- 生产就绪:被 AWS 等公司广泛使用,企业级支持
- 跨平台:支持 Windows、Linux、macOS 等操作系统
- GPU 支持:支持 CUDA、MPS 等 GPU 加速
类比理解:
- 就像统一遥控器:一个遥控器控制所有电器,不需要为每个电器准备一个遥控器
- 就像多语言翻译器:可以理解多种语言,自动翻译成目标语言
- 就像通用接口:标准化的接口设计,任何设备都能连接
为什么选择 DJL?
与其他 Java 深度学习库的对比:
| 特性 | DJL | Deeplearning4j | Smile | ONNX Runtime |
|---|---|---|---|---|
| 框架支持 | ⭐⭐⭐⭐⭐ 多种框架 | ⭐⭐ 仅自己的实现 | ⭐⭐ 传统 ML | ⭐⭐⭐⭐ 仅 ONNX |
| 统一 API | ✅ 是 | ❌ 否 | ❌ 否 | ⚠️ 仅 ONNX |
| Hugging Face | ✅ 直接支持 | ❌ 不支持 | ❌ 不支持 | ⚠️ 需转换 |
| 模型训练 | ✅ 支持 | ✅ 支持 | ✅ 支持 | ❌ 仅推理 |
| 易用性 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| 性能 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| 社区支持 | ⭐⭐⭐⭐ AWS 支持 | ⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| 文档质量 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
选择 DJL 的理由:
- ✅ 框架无关:一套代码支持所有框架,灵活切换
- ✅ 统一 API:学习成本低,一套 API 掌握所有框架
- ✅ Hugging Face 集成:可以直接加载模型,无需转换
- ✅ 易于集成:Java 风格,易于集成到现有项目
- ✅ AWS 支持:企业级支持,持续更新
- ✅ 生产就绪:被广泛使用,稳定可靠
DJL 的适用场景
适合使用 DJL:
- ✅ Java 项目需要深度学习功能
- ✅ 需要加载 Hugging Face 模型
- ✅ 需要支持多种框架
- ✅ 需要统一的 API
- ✅ 企业级应用
- ✅ 微服务架构中的模型服务
不适合使用 DJL:
- ❌ 只需要传统机器学习(使用 Smile)
- ❌ 只需要 ONNX 模型推理(使用 ONNX Runtime)
- ❌ Python 项目(使用原生框架更合适)
- ❌ 需要复杂的模型训练(Python 生态更丰富)
第二部分:DJL 的安装与配置
系统要求
Java 版本:
- Java 8 或更高版本(推荐 Java 11+)
操作系统:
- Windows 10/11
- Linux(Ubuntu、CentOS 等)
- macOS 10.14+
内存:
- 至少 4GB RAM(推荐 8GB+)
- GPU 内存根据模型大小而定
Maven 依赖配置
1. 基础依赖(必需):
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.25.0</version>
</dependency>
2. 选择引擎(至少选择一个):
<!-- PyTorch 引擎(最常用,推荐) -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.25.0</version>
</dependency>
<!-- TensorFlow 引擎 -->
<dependency>
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-engine</artifactId>
<version>0.25.0</version>
</dependency>
<!-- ONNX Runtime 引擎 -->
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
<version>0.25.0</version>
</dependency>
<!-- MXNet 引擎 -->
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-engine</artifactId>
<version>0.25.0</version>
</dependency>
3. PyTorch 原生库(CPU 版本):
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>${os}-x86_64</classifier>
</dependency>
4. PyTorch 原生库(GPU 版本 - CUDA 11.8):
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cu118</artifactId>
<version>2.0.1</version>
<classifier>${os}-x86_64</classifier>
</dependency>
5. Hugging Face 集成(可选):
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.25.0</version>
</dependency>
6. 图像处理支持(可选):
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.25.0</version>
</dependency>
完整的 pom.xml 示例:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>djl-demo</artifactId>
<version>1.0.0</version>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<os.detected.classifier>${os.detected.name}-${os.detected.arch}</os.detected.classifier>
</properties>
<dependencies>
<!-- DJL API -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.25.0</version>
</dependency>
<!-- PyTorch 引擎 -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.25.0</version>
</dependency>
<!-- PyTorch 原生库(CPU) -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>${os.detected.classifier}</classifier>
</dependency>
<!-- Model Zoo -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.25.0</version>
</dependency>
<!-- Hugging Face Tokenizers -->
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.25.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>11</source>
<target>11</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
Gradle 依赖配置
build.gradle 示例:
plugins {
id 'java'
}
repositories {
mavenCentral()
}
dependencies {
implementation 'ai.djl:api:0.25.0'
implementation 'ai.djl.pytorch:pytorch-engine:0.25.0'
implementation 'ai.djl.pytorch:pytorch-native-cpu:2.0.1:linux-x86_64'
implementation 'ai.djl:model-zoo:0.25.0'
implementation 'ai.djl.huggingface:tokenizers:0.25.0'
}
java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}
验证安装
创建测试类验证安装:
import ai.djl.engine.Engine;
import java.util.List;
public class DJLInstallationTest {
public static void main(String[] args) {
System.out.println("=== DJL 安装验证 ===");
// 1. 检查 DJL 版本
System.out.println("DJL 版本: " + Engine.getDjlVersion());
// 2. 获取所有可用引擎
List<Engine> engines = Engine.getAllEngines();
System.out.println("\n可用引擎数量: " + engines.size());
for (Engine engine : engines) {
System.out.println("\n引擎信息:");
System.out.println(" 名称: " + engine.getEngineName());
System.out.println(" 版本: " + engine.getVersion());
System.out.println(" GPU 支持: " + engine.hasCapability("CUDA"));
}
// 3. 获取默认引擎
Engine defaultEngine = Engine.getInstance();
System.out.println("\n默认引擎: " + defaultEngine.getEngineName());
// 4. 检查 GPU 可用性
if (defaultEngine.hasCapability("CUDA")) {
System.out.println("✓ GPU 可用");
} else {
System.out.println("✗ GPU 不可用,将使用 CPU");
}
System.out.println("\n=== 安装验证完成 ===");
}
}
运行结果示例:
=== DJL 安装验证 ===
DJL 版本: 0.25.0
可用引擎数量: 1
引擎信息:
名称: PyTorch
版本: 2.0.1
GPU 支持: false
默认引擎: PyTorch
✗ GPU 不可用,将使用 CPU
=== 安装验证完成 ===
类比理解:
- 就像安装软件:添加依赖 → 验证安装 → 开始使用
- 就像准备工具:拿到工具 → 检查工具 → 开始工作
- 就像体检:检查各项指标 → 确认健康 → 开始工作
第三部分:DJL 的核心概念
1. Engine(引擎)
引擎是 DJL 对不同深度学习框架的抽象。每个引擎对应一个深度学习框架。
支持的引擎:
- PyTorch:最常用,支持 Hugging Face 模型,推荐用于新项目
- TensorFlow:Google 的框架,适合 TensorFlow 模型
- MXNet:Apache 的框架,轻量级
- ONNX Runtime:跨平台推理引擎,性能优秀
- PaddlePaddle:百度的框架,适合中文场景
类比理解:
- 就像驱动程序:不同的硬件需要不同的驱动,不同的框架需要不同的引擎
- 就像翻译器:不同的语言需要不同的翻译器,不同的框架需要不同的引擎
- 就像适配器:不同的接口需要不同的适配器
使用示例:
import ai.djl.engine.Engine;
import java.util.List;
public class EngineExample {
public static void main(String[] args) {
// 1. 获取默认引擎
Engine engine = Engine.getInstance();
System.out.println("默认引擎: " + engine.getEngineName());
System.out.println("引擎版本: " + engine.getVersion());
// 2. 获取所有可用引擎
List<Engine> engines = Engine.getAllEngines();
System.out.println("\n所有可用引擎:");
for (Engine e : engines) {
System.out.println(" - " + e.getEngineName() + " " + e.getVersion());
}
// 3. 指定使用特定引擎
Engine pytorch = Engine.getEngine("PyTorch");
if (pytorch != null) {
System.out.println("\nPyTorch 引擎可用");
System.out.println(" GPU 支持: " + pytorch.hasCapability("CUDA"));
}
// 4. 检查引擎能力
Engine defaultEngine = Engine.getInstance();
System.out.println("\n引擎能力:");
System.out.println(" CUDA: " + defaultEngine.hasCapability("CUDA"));
System.out.println(" MKL: " + defaultEngine.hasCapability("MKL"));
}
}
2. Model(模型)
模型是 DJL 对机器学习模型的抽象。模型可以是预训练的,也可以是自己训练的。
模型类型:
- ZooModel:从 ModelZoo 加载的预训练模型
- Model:自定义训练的模型
类比理解:
- 就像容器:装载训练好的模型
- 就像包装盒:包装不同格式的模型
- 就像程序:封装了算法逻辑
加载模型示例:
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.Application;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;
public class ModelExample {
public static void main(String[] args) throws Exception {
// 1. 定义模型标准
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.optFilter("dataset", "imagenet")
.optFilter("layers", "50")
.optEngine("PyTorch")
.build();
// 2. 加载模型
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
System.out.println("模型加载成功");
System.out.println("模型名称: " + model.getName());
System.out.println("模型路径: " + model.getModelPath());
}
}
}
3. Translator(转换器)
转换器负责将输入数据转换为模型需要的格式,以及将模型输出转换为 Java 对象。
转换器的作用:
- 输入转换:将 Java 对象(如 Image、String)转换为 NDArray
- 输出转换:将 NDArray 转换为 Java 对象(如 Classifications、String)
类比理解:
- 就像翻译器:将 Java 数据翻译成模型格式
- 就像适配器:适配不同格式的数据
- 就像转换器:将一种格式转换为另一种格式
自定义转换器示例:
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;
public class ImageClassificationTranslator implements Translator<Image, Classifications> {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
// 1. 调整图像大小
Image resized = input.resize(224, 224);
// 2. 转换为 NDArray
NDArray array = resized.toNDArray(manager, Image.Flag.COLOR);
// 3. 归一化(ImageNet 标准)
array = array.div(255.0f);
array = array.sub(new float[]{0.485f, 0.456f, 0.406f});
array = array.div(new float[]{0.229f, 0.224f, 0.225f});
// 4. 添加批次维度 [1, 3, 224, 224]
array = array.expandDims(0);
return new NDList(array);
}
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.singletonOrThrow();
// 1. 应用 softmax
output = output.softmax(1);
// 2. 转换为 Classifications
return new Classifications(output);
}
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
}
4. Predictor(预测器)
预测器用于运行模型进行预测。它是模型和转换器的组合。
类比理解:
- 就像执行器:执行模型推理
- 就像计算器:输入数据,输出结果
- 就像预测器:根据输入预测输出
使用示例:
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
public class PredictorExample {
public static void main(String[] args) throws Exception {
// 假设已经加载了模型
ZooModel<Image, Classifications> model = loadModel();
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
// 1. 加载图像
Image image = ImageFactory.getInstance()
.fromUrl("https://example.com/cat.jpg");
// 2. 进行预测
Classifications classifications = predictor.predict(image);
// 3. 显示结果
System.out.println("预测结果:");
classifications.topK(5).forEach(item ->
System.out.println(String.format(
" %s: %.2f%%",
item.getClassName(),
item.getProbability() * 100
))
);
}
}
private static ZooModel<Image, Classifications> loadModel() {
// 加载模型的代码
return null;
}
}
5. NDArray(多维数组)
NDArray是 DJL 的核心数据结构,类似于 NumPy 的 ndarray。
NDArray 的特点:
- 多维数组:支持任意维度的数组
- 类型支持:支持 float、double、int、long 等类型
- 操作丰富:支持各种数学运算、形状变换等
- 内存管理:自动管理内存,避免内存泄漏
类比理解:
- 就像NumPy 数组:Python 中的多维数组
- 就像矩阵:数学中的矩阵运算
- 就像张量:深度学习中的基本数据结构
NDArray 使用示例:
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
public class NDArrayExample {
public static void main(String[] args) {
try (NDManager manager = NDManager.newBaseManager()) {
// 1. 创建数组
NDArray array = manager.create(new float[]{1, 2, 3, 4, 5, 6});
System.out.println("原始数组: " + array);
// 2. 改变形状
NDArray reshaped = array.reshape(2, 3);
System.out.println("重塑后: " + reshaped);
// 3. 数学运算
NDArray doubled = array.mul(2);
System.out.println("乘以2: " + doubled);
// 4. 创建零数组
NDArray zeros = manager.zeros(new Shape(3, 4));
System.out.println("零数组:\n" + zeros);
// 5. 创建随机数组
NDArray random = manager.randomNormal(new Shape(2, 3));
System.out.println("随机数组:\n" + random);
}
}
}
第四部分:DJL 的基础使用
示例1:图像分类(完整示例)
使用预训练模型进行图像分类:
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.nio.file.Paths;
public class ImageClassificationExample {
public static void main(String[] args) {
try {
// 1. 定义模型标准
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.optFilter("dataset", "imagenet")
.optFilter("layers", "50")
.optEngine("PyTorch")
.build();
// 2. 加载模型
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<Image, Classifications> predictor = model.newPredictor()) {
System.out.println("模型加载成功: " + model.getName());
// 3. 加载图像(可以从文件、URL 或字节数组加载)
Image image = ImageFactory.getInstance()
.fromFile(Paths.get("path/to/image.jpg"));
// 或者从 URL 加载
// Image image = ImageFactory.getInstance()
// .fromUrl("https://example.com/image.jpg");
// 4. 进行预测
long startTime = System.currentTimeMillis();
Classifications classifications = predictor.predict(image);
long endTime = System.currentTimeMillis();
// 5. 显示结果
System.out.println("\n预测结果(Top 5):");
System.out.println("预测耗时: " + (endTime - startTime) + "ms");
System.out.println("----------------------------------------");
classifications.topK(5).forEach(item ->
System.out.println(String.format(
" %-30s: %.2f%%",
item.getClassName(),
item.getProbability() * 100
))
);
}
} catch (ModelException e) {
System.err.println("模型加载失败: " + e.getMessage());
e.printStackTrace();
} catch (TranslateException e) {
System.err.println("预测失败: " + e.getMessage());
e.printStackTrace();
} catch (Exception e) {
System.err.println("发生错误: " + e.getMessage());
e.printStackTrace();
}
}
}
示例2:文本分类(Hugging Face)
使用 Hugging Face 模型进行文本分类:
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Arrays;
public class TextClassificationExample {
// 自定义转换器
static class TextClassificationTranslator implements Translator<String, Classifications> {
private HuggingFaceTokenizer tokenizer;
public TextClassificationTranslator(String tokenizerPath) throws Exception {
tokenizer = HuggingFaceTokenizer.builder()
.optTokenizerPath(Paths.get(tokenizerPath))
.optPadding(true)
.optMaxLength(128)
.build();
}
@Override
public NDList processInput(TranslatorContext ctx, String input) {
NDManager manager = ctx.getNDManager();
// 1. 分词
Encoding encoding = tokenizer.encode(input);
// 2. 转换为 NDArray
long[] inputIds = encoding.getIds();
long[] attentionMask = encoding.getAttentionMask();
NDArray inputIdsArray = manager.create(inputIds).expandDims(0);
NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
return new NDList(inputIdsArray, attentionMaskArray);
}
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.singletonOrThrow();
// 应用 softmax
output = output.softmax(1);
// 转换为 Classifications
return new Classifications(output,
Arrays.asList("negative", "positive"));
}
}
public static void main(String[] args) throws Exception {
// 1. 定义模型标准
Criteria<String, Classifications> criteria = Criteria.builder()
.setTypes(String.class, Classifications.class)
.optModelPath(Paths.get("models/distilbert-sentiment"))
.optModelName("distilbert-sentiment")
.optEngine("PyTorch")
.optTranslator(new TextClassificationTranslator("tokenizer.json"))
.build();
// 2. 加载模型
try (ZooModel<String, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<String, Classifications> predictor = model.newPredictor()) {
// 3. 测试文本
String[] texts = {
"I love Java programming!",
"This is terrible.",
"The weather is nice today.",
"I hate this product."
};
// 4. 进行预测
for (String text : texts) {
Classifications result = predictor.predict(text);
System.out.println("文本: " + text);
System.out.println("情感: " + result.best().getClassName());
System.out.println("置信度: " +
String.format("%.2f%%", result.best().getProbability() * 100));
System.out.println("---");
}
}
}
}
示例3:目标检测
使用预训练模型进行目标检测:
import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;
public class ObjectDetectionExample {
public static void main(String[] args) throws Exception {
// 1. 定义模型标准
Criteria<Image, DetectedObjects> criteria = Criteria.builder()
.setTypes(Image.class, DetectedObjects.class)
.optApplication(Application.CV.OBJECT_DETECTION)
.optFilter("backbone", "resnet50")
.optEngine("PyTorch")
.build();
// 2. 加载模型
try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
// 3. 加载图像
Image image = ImageFactory.getInstance()
.fromFile(Paths.get("path/to/image.jpg"));
// 4. 进行检测
DetectedObjects detections = predictor.predict(image);
// 5. 显示结果
System.out.println("检测到 " + detections.getNumberOfObjects() + " 个对象:");
System.out.println("----------------------------------------");
detections.items().forEach(item -> {
System.out.println(String.format(
"类别: %s (%.2f%%)",
item.getClassName(),
item.getProbability() * 100
));
System.out.println(String.format(
"位置: [%.0f, %.0f, %.0f, %.0f]",
item.getBoundingBox().getX(),
item.getBoundingBox().getY(),
item.getBoundingBox().getWidth(),
item.getBoundingBox().getHeight()
));
System.out.println("---");
});
}
}
}
示例4:批量处理
批量处理多个图像:
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class BatchProcessingExample {
public static void main(String[] args) throws Exception {
// 假设已经加载了模型
ZooModel<Image, Classifications> model = loadModel();
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
// 1. 准备批量图像路径
List<Path> imagePaths = Arrays.asList(
Paths.get("image1.jpg"),
Paths.get("image2.jpg"),
Paths.get("image3.jpg")
);
// 2. 加载所有图像
List<Image> images = imagePaths.stream()
.map(path -> {
try {
return ImageFactory.getInstance().fromFile(path);
} catch (Exception e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
// 3. 批量预测
long startTime = System.currentTimeMillis();
List<Classifications> results = images.stream()
.map(predictor::predict)
.collect(Collectors.toList());
long endTime = System.currentTimeMillis();
// 4. 处理结果
System.out.println("批量处理完成");
System.out.println("处理数量: " + images.size());
System.out.println("总耗时: " + (endTime - startTime) + "ms");
System.out.println("平均耗时: " + (endTime - startTime) / images.size() + "ms/张");
System.out.println("----------------------------------------");
for (int i = 0; i < results.size(); i++) {
Classifications classification = results.get(i);
System.out.println("图像 " + (i + 1) + ":");
System.out.println(" 预测: " + classification.best().getClassName());
System.out.println(" 置信度: " +
String.format("%.2f%%", classification.best().getProbability() * 100));
System.out.println("---");
}
}
}
private static ZooModel<Image, Classifications> loadModel() {
// 加载模型的代码
return null;
}
}
第五部分:加载 Hugging Face 模型
方法1:直接加载 Hugging Face 模型
从本地路径加载 Hugging Face 模型:
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.Application;
import ai.djl.modality.nlp.qa.QAInput;
import java.nio.file.Paths;
public class HuggingFaceModelExample {
public static void main(String[] args) throws Exception {
// 1. 定义模型标准
Criteria<QAInput, String> criteria = Criteria.builder()
.setTypes(QAInput.class, String.class)
.optModelPath(Paths.get("models/bert-base-uncased"))
.optModelName("bert-base-uncased")
.optEngine("PyTorch")
.optApplication(Application.NLP.QUESTION_ANSWER)
.build();
// 2. 加载模型
try (ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria);
Predictor<QAInput, String> predictor = model.newPredictor()) {
// 3. 准备输入
QAInput input = new QAInput(
"What is the capital of France?",
"Paris is the capital of France. It is a beautiful city."
);
// 4. 进行预测
String answer = predictor.predict(input);
System.out.println("问题: " + input.getQuestion());
System.out.println("上下文: " + input.getParagraph());
System.out.println("答案: " + answer);
}
}
}
方法2:从 Hugging Face Hub 下载模型
自动从 Hugging Face Hub 下载模型:
import ai.djl.huggingface.translator.TextClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;
public class HuggingFaceHubExample {
public static void main(String[] args) throws Exception {
// 1. 定义模型标准(会自动从 Hugging Face Hub 下载)
Criteria<String, Classifications> criteria = Criteria.builder()
.setTypes(String.class, Classifications.class)
.optModelName("distilbert-base-uncased-finetuned-sst-2-english")
.optModelPath(Paths.get("./models")) // 下载到本地目录
.optEngine("PyTorch")
.optTranslatorFactory(new TextClassificationTranslatorFactory())
.build();
// 2. 加载模型(首次会下载,之后会使用缓存)
ZooModel<String, Classifications> model = ModelZoo.loadModel(criteria);
// 使用模型...
model.close();
}
}
方法3:使用 Hugging Face Tokenizers
使用 Hugging Face Tokenizers 进行文本处理:
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import java.nio.file.Paths;
public class HuggingFaceTokenizerExample {
public static void main(String[] args) throws Exception {
// 1. 加载分词器
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder()
.optTokenizerPath(Paths.get("tokenizer.json"))
.optPadding(true)
.optMaxLength(128)
.optTruncation(true)
.build();
// 2. 分词
String text = "Hello, how are you?";
Encoding encoding = tokenizer.encode(text);
// 3. 获取结果
System.out.println("原始文本: " + text);
System.out.println("Token IDs: " + Arrays.toString(encoding.getIds()));
System.out.println("Attention Mask: " + Arrays.toString(encoding.getAttentionMask()));
System.out.println("Token 数量: " + encoding.getIds().length);
// 4. 批量编码
List<String> texts = Arrays.asList(
"Hello, world!",
"How are you?",
"I love Java!"
);
List<Encoding> encodings = tokenizer.batchEncode(texts);
for (int i = 0; i < encodings.size(); i++) {
System.out.println("\n文本 " + (i + 1) + ": " + texts.get(i));
System.out.println("Token IDs: " + Arrays.toString(encodings.get(i).getIds()));
}
}
}
第六部分:自定义模型和转换器
自定义转换器详解
创建完整的图像预处理转换器:
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;
public class CustomImageTranslator implements Translator<Image, Classifications> {
private int imageSize;
private float[] mean;
private float[] std;
public CustomImageTranslator() {
this(224,
new float[]{0.485f, 0.456f, 0.406f},
new float[]{0.229f, 0.224f, 0.225f});
}
public CustomImageTranslator(int imageSize, float[] mean, float[] std) {
this.imageSize = imageSize;
this.mean = mean;
this.std = std;
}
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
// 1. 调整大小(保持宽高比)
Image resized = input.resize(imageSize, imageSize);
// 2. 转换为 NDArray [H, W, C]
NDArray array = resized.toNDArray(manager, Image.Flag.COLOR);
// 3. 转换为 [C, H, W] 格式
array = array.transpose(2, 0, 1);
// 4. 归一化到 [0, 1]
array = array.div(255.0f);
// 5. 标准化(减去均值,除以标准差)
NDArray meanArray = manager.create(mean).reshape(3, 1, 1);
NDArray stdArray = manager.create(std).reshape(3, 1, 1);
array = array.sub(meanArray).div(stdArray);
// 6. 添加批次维度 [1, C, H, W]
array = array.expandDims(0);
return new NDList(array);
}
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.singletonOrThrow();
// 1. 移除批次维度(如果有)
if (output.getShape().dimension() > 1) {
output = output.squeeze(0);
}
// 2. 应用 softmax
output = output.softmax(0);
// 3. 转换为 Classifications
return new Classifications(output);
}
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
}
自定义模型训练
创建简单的神经网络模型:
import ai.djl.Model;
import ai.djl.nn.*;
import ai.djl.nn.core.Linear;
import ai.djl.nn.Activation;
import ai.djl.nn.SequentialBlock;
import ai.djl.ndarray.types.Shape;
import java.nio.file.Paths;
public class CustomModelExample {
public static void main(String[] args) throws Exception {
// 1. 创建模型架构
SequentialBlock block = new SequentialBlock()
.add(Linear.builder().setUnits(128).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(64).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(10).build())
.add(Activation::softmax);
// 2. 创建模型实例
Model model = Model.newInstance("custom_model");
model.setBlock(block);
// 3. 初始化参数
model.getBlock().initialize(
model.getNDManager(),
new Shape(1, 784) // 输入形状:[batch_size, features]
);
// 4. 训练模型(示例代码,实际需要数据集和训练循环)
// trainModel(model);
// 5. 保存模型
model.save(Paths.get("models"), "custom_model");
System.out.println("模型保存成功");
// 6. 关闭模型
model.close();
}
}
第七部分:实际应用案例
案例1:图像背景移除(BEN2 模型)
使用 BEN2 模型移除图像背景:
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import java.nio.file.Paths;
public class BEN2BackgroundRemoval {
// 自定义转换器
static class BEN2Translator implements Translator<Image, Image> {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
// 1. 调整大小到 1024x1024
Image resized = input.resize(1024, 1024);
// 2. 转换为 NDArray [H, W, C]
NDArray array = resized.toNDArray(manager, Image.Flag.COLOR);
// 3. 归一化到 [0, 1]
array = array.div(255.0f);
// 4. 转换为 [C, H, W] 格式并添加批次维度
array = array.transpose(2, 0, 1).expandDims(0);
return new NDList(array);
}
@Override
public Image processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.singletonOrThrow();
// 1. 移除批次维度并转换回 [H, W, C]
output = output.squeeze(0).transpose(1, 2, 0);
// 2. 反归一化
output = output.mul(255.0f).clip(0, 255);
// 3. 转换为 Image
return ImageFactory.getInstance().fromNDArray(output);
}
}
public static void main(String[] args) throws Exception {
// 1. 定义模型标准
Criteria<Image, Image> criteria = Criteria.builder()
.setTypes(Image.class, Image.class)
.optModelPath(Paths.get("models/ben2"))
.optModelName("ben2")
.optEngine("PyTorch")
.optTranslator(new BEN2Translator())
.build();
// 2. 加载模型
try (ZooModel<Image, Image> model = ModelZoo.loadModel(criteria);
Predictor<Image, Image> predictor = model.newPredictor()) {
// 3. 加载输入图像
Image inputImage = ImageFactory.getInstance()
.fromFile(Paths.get("input.jpg"));
// 4. 移除背景
Image foreground = predictor.predict(inputImage);
// 5. 保存结果
foreground.save(Paths.get("foreground.png"), "png");
System.out.println("背景移除完成!");
}
}
}
案例2:文本情感分析(完整实现)
完整的文本情感分析系统:
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
public class SentimentAnalysisSystem {
static class SentimentTranslator implements Translator<String, Classifications> {
private HuggingFaceTokenizer tokenizer;
public SentimentTranslator(String tokenizerPath) throws Exception {
tokenizer = HuggingFaceTokenizer.builder()
.optTokenizerPath(Paths.get(tokenizerPath))
.optPadding(true)
.optMaxLength(128)
.optTruncation(true)
.build();
}
@Override
public NDList processInput(TranslatorContext ctx, String input) {
NDManager manager = ctx.getNDManager();
// 分词
Encoding encoding = tokenizer.encode(input);
// 转换为 NDArray
long[] inputIds = encoding.getIds();
long[] attentionMask = encoding.getAttentionMask();
NDArray inputIdsArray = manager.create(inputIds).expandDims(0);
NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
return new NDList(inputIdsArray, attentionMaskArray);
}
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.singletonOrThrow();
// 应用 softmax
output = output.softmax(1);
// 转换为 Classifications
return new Classifications(output,
Arrays.asList("negative", "positive"));
}
}
public static void main(String[] args) throws Exception {
// 1. 初始化系统
System.out.println("初始化情感分析系统...");
Criteria<String, Classifications> criteria = Criteria.builder()
.setTypes(String.class, Classifications.class)
.optModelPath(Paths.get("models/sentiment_model"))
.optModelName("distilbert-sentiment")
.optEngine("PyTorch")
.optTranslator(new SentimentTranslator("tokenizer.json"))
.build();
try (ZooModel<String, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<String, Classifications> predictor = model.newPredictor()) {
System.out.println("系统初始化完成!\n");
// 2. 测试文本
List<String> testTexts = Arrays.asList(
"I love Java programming!",
"This is terrible.",
"The weather is nice today.",
"I hate this product.",
"This is amazing!"
);
// 3. 批量分析
System.out.println("开始情感分析...\n");
for (String text : testTexts) {
Classifications result = predictor.predict(text);
String sentiment = result.best().getClassName();
double confidence = result.best().getProbability();
// 格式化输出
System.out.println("文本: " + text);
System.out.println("情感: " + sentiment.toUpperCase());
System.out.println("置信度: " + String.format("%.2f%%", confidence * 100));
// 显示所有类别
System.out.println("详细结果:");
result.items().forEach(item ->
System.out.println(" " + item.getClassName() + ": " +
String.format("%.2f%%", item.getProbability() * 100))
);
System.out.println("---");
}
}
}
}
案例3:Spring Boot 集成
在 Spring Boot 应用中集成 DJL:
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@Service
public class ImageClassificationService {
@Autowired
private ModelLoader modelLoader; // 自定义的模型加载器
private ZooModel<Image, Classifications> model;
private Predictor<Image, Classifications> predictor;
@PostConstruct
public void init() throws Exception {
// 初始化时加载模型
model = modelLoader.loadModel();
predictor = model.newPredictor();
}
@PreDestroy
public void destroy() {
// 关闭时释放资源
if (predictor != null) {
predictor.close();
}
if (model != null) {
model.close();
}
}
public ClassificationResult classifyImage(MultipartFile file) throws Exception {
// 1. 保存临时文件
Path tempFile = Files.createTempFile("upload", ".jpg");
file.transferTo(tempFile.toFile());
try {
// 2. 加载图像
Image image = ImageFactory.getInstance().fromFile(tempFile);
// 3. 进行预测
Classifications classifications = predictor.predict(image);
// 4. 构建结果
ClassificationResult result = new ClassificationResult();
result.setTopPrediction(
classifications.best().getClassName(),
classifications.best().getProbability()
);
result.setTopK(classifications.topK(5));
return result;
} finally {
// 5. 删除临时文件
Files.deleteIfExists(tempFile);
}
}
}
// REST Controller
@RestController
@RequestMapping("/api/classify")
public class ClassificationController {
@Autowired
private ImageClassificationService service;
@PostMapping("/image")
public ResponseEntity<ClassificationResult> classifyImage(
@RequestParam("file") MultipartFile file) {
try {
ClassificationResult result = service.classifyImage(file);
return ResponseEntity.ok(result);
} catch (Exception e) {
return ResponseEntity.status(500).build();
}
}
}
第八部分:DJL 的高级特性
1. GPU 支持与设备管理
使用 GPU 加速推理:
import ai.djl.engine.Engine;
import ai.djl.Device;
import ai.djl.repository.zoo.Criteria;
public class GPUSupportExample {
public static void main(String[] args) {
// 1. 检查 GPU 可用性
Engine engine = Engine.getInstance();
boolean hasGPU = engine.hasCapability("CUDA");
System.out.println("GPU 可用: " + hasGPU);
// 2. 获取 GPU 设备
if (hasGPU) {
// 使用第一个 GPU
Device gpu0 = Device.gpu(0);
// 使用第二个 GPU(如果有)
// Device gpu1 = Device.gpu(1);
// 3. 在模型标准中指定设备
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optDevice(gpu0) // 指定使用 GPU
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.build();
System.out.println("将使用 GPU 进行推理");
} else {
// 使用 CPU
Device cpu = Device.cpu();
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optDevice(cpu) // 指定使用 CPU
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.build();
System.out.println("将使用 CPU 进行推理");
}
}
}
2. 模型管理与缓存
管理模型缓存和版本:
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Path;
import java.nio.file.Paths;
public class ModelManagementExample {
public static void main(String[] args) throws Exception {
// 1. 指定模型缓存目录
Path modelDir = Paths.get("./models_cache");
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.optModelPath(modelDir) // 指定缓存目录
.build();
// 2. 加载模型(首次会下载,之后使用缓存)
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
System.out.println("模型路径: " + model.getModelPath());
System.out.println("模型名称: " + model.getName());
// 3. 保存模型到指定位置
Path savePath = Paths.get("./saved_models/my_model");
model.save(savePath, "my_model");
System.out.println("模型已保存到: " + savePath);
}
}
}
3. 性能优化技巧
模型预热和批量处理优化:
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;
public class PerformanceOptimizationExample {
public static void warmupModel(ZooModel<Image, Classifications> model) throws Exception {
System.out.println("预热模型...");
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
// 创建虚拟图像进行预热
Image dummyImage = ImageFactory.getInstance()
.fromFile(Paths.get("dummy.jpg"));
// 运行多次预热
for (int i = 0; i < 10; i++) {
predictor.predict(dummyImage);
}
}
System.out.println("模型预热完成");
}
public static void benchmarkModel(ZooModel<Image, Classifications> model) throws Exception {
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
Image testImage = ImageFactory.getInstance()
.fromFile(Paths.get("test.jpg"));
// 单次推理基准测试
int warmupRuns = 5;
int testRuns = 100;
// 预热
for (int i = 0; i < warmupRuns; i++) {
predictor.predict(testImage);
}
// 测试
long totalTime = 0;
for (int i = 0; i < testRuns; i++) {
long startTime = System.nanoTime();
predictor.predict(testImage);
long endTime = System.nanoTime();
totalTime += (endTime - startTime);
}
double avgTime = totalTime / (double) testRuns / 1_000_000; // 转换为毫秒
System.out.println("平均推理时间: " + String.format("%.2f", avgTime) + "ms");
System.out.println("吞吐量: " + String.format("%.2f", 1000 / avgTime) + " 次/秒");
}
}
}
4. 多线程推理
使用线程池进行并发推理:
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Classifications;
import ai.djl.repository.zoo.ZooModel;
import java.util.List;
import java.util.concurrent.*;
import java.util.stream.Collectors;
public class MultiThreadInferenceExample {
public static List<Classifications> predictBatch(
ZooModel<Image, Classifications> model,
List<Image> images,
int threadPoolSize) throws Exception {
// 1. 创建线程池
ExecutorService executor = Executors.newFixedThreadPool(threadPoolSize);
try {
// 2. 为每个线程创建独立的 Predictor
List<Future<Classifications>> futures = images.stream()
.map(image -> executor.submit(() -> {
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
return predictor.predict(image);
}
}))
.collect(Collectors.toList());
// 3. 收集结果
return futures.stream()
.map(future -> {
try {
return future.get();
} catch (Exception e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
} finally {
executor.shutdown();
}
}
}
第九部分:最佳实践
1. 资源管理
使用 try-with-resources 确保资源正确释放:
// ✅ 正确:自动关闭资源
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<Image, Classifications> predictor = model.newPredictor()) {
// 使用模型
Classifications result = predictor.predict(image);
}
// ❌ 错误:手动管理资源(容易忘记关闭)
ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<Image, Classifications> predictor = model.newPredictor();
// 使用后需要手动关闭
predictor.close();
model.close();
2. 错误处理
完善的错误处理机制:
import ai.djl.ModelException;
import ai.djl.translate.TranslateException;
public class ErrorHandlingExample {
public static void classifyImage(String imagePath) {
try {
ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
// 使用模型...
} catch (ModelException e) {
System.err.println("模型加载失败: " + e.getMessage());
// 记录日志
logger.error("模型加载失败", e);
// 使用备用模型或返回默认结果
} catch (TranslateException e) {
System.err.println("推理失败: " + e.getMessage());
logger.error("推理失败", e);
} catch (Exception e) {
System.err.println("发生未知错误: " + e.getMessage());
logger.error("未知错误", e);
}
}
}
3. 日志记录
配置 DJL 日志:
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class LoggingExample {
private static final Logger logger = LoggerFactory.getLogger(LoggingExample.class);
public static void main(String[] args) {
// 设置 DJL 日志级别
System.setProperty("ai.djl.logging.level", "INFO");
logger.info("开始加载模型...");
try {
ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
logger.info("模型加载成功");
} catch (Exception e) {
logger.error("模型加载失败", e);
}
}
}
4. 配置管理
使用配置文件管理模型参数:
import java.util.Properties;
import java.io.InputStream;
public class ConfigExample {
public static Properties loadConfig() throws Exception {
Properties props = new Properties();
try (InputStream is = ConfigExample.class
.getResourceAsStream("/djl.properties")) {
props.load(is);
}
return props;
}
public static Criteria<Image, Classifications> buildCriteriaFromConfig()
throws Exception {
Properties config = loadConfig();
return Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.optEngine(config.getProperty("engine", "PyTorch"))
.optModelPath(Paths.get(config.getProperty("model.path", "./models")))
.build();
}
}
第十部分:常见问题与解决方案
问题1:找不到引擎
错误信息:
No engine found for: PyTorch
解决方案:
- 检查依赖配置:
<!-- 确保添加了引擎依赖 -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.25.0</version>
</dependency>
<!-- 以及原生库 -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>${os}-x86_64</classifier>
</dependency>
- 检查操作系统分类器:
<properties>
<os.detected.classifier>${os.detected.name}-${os.detected.arch}</os.detected.classifier>
</properties>
问题2:模型加载失败
解决方案:
// 1. 检查模型路径
Path modelPath = Paths.get("models/my_model");
if (!Files.exists(modelPath)) {
System.err.println("模型路径不存在: " + modelPath);
return;
}
// 2. 检查模型格式
// 确保模型格式与引擎匹配(PyTorch 模型使用 PyTorch 引擎)
// 3. 使用正确的引擎
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelPath(modelPath)
.optEngine("PyTorch") // 明确指定引擎
.build();
问题3:内存不足
解决方案:
// 1. 减少批次大小
// 2. 使用 CPU 而不是 GPU(如果 GPU 内存不足)
Device device = Device.cpu();
// 3. 及时释放资源
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
// 使用后自动关闭
}
// 4. 增加 JVM 内存
// java -Xmx4g YourApplication
// 5. 使用模型量化(如果支持)
问题4:GPU 不可用
解决方案:
// 检查 GPU 可用性
Engine engine = Engine.getInstance();
if (engine.hasCapability("CUDA")) {
Device device = Device.gpu(0);
// 使用 GPU
} else {
Device device = Device.cpu();
// 使用 CPU
System.out.println("GPU 不可用,使用 CPU");
}
问题5:性能问题
优化建议:
- 模型预热: 在正式使用前运行几次推理
- 批量处理: 使用批量处理提高吞吐量
- 使用 GPU: 如果可用,使用 GPU 加速
- 模型量化: 使用量化模型减少内存和计算量
- 多线程: 使用多线程进行并发推理
第十一部分:DJL vs 其他方案
对比 ONNX Runtime
| 特性 | DJL | ONNX Runtime |
|---|---|---|
| 框架支持 | ✅ 多种框架 | ⚠️ 仅 ONNX |
| 统一 API | ✅ 是 | ⚠️ 仅 ONNX |
| Hugging Face | ✅ 直接支持 | ❌ 需转换 |
| 模型训练 | ✅ 支持 | ❌ 仅推理 |
| 易用性 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| 性能 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
选择建议:
- 如果需要多种框架支持 → 选择 DJL
- 如果只需要 ONNX 模型推理 → 选择 ONNX Runtime
对比 Deeplearning4j
| 特性 | DJL | Deeplearning4j |
|---|---|---|
| 框架支持 | ✅ 多种框架 | ❌ 仅自己的实现 |
| Hugging Face | ✅ 直接支持 | ❌ 不支持 |
| API 设计 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ |
| 易用性 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ |
| 社区支持 | ⭐⭐⭐⭐ AWS 支持 | ⭐⭐⭐ |
选择建议:
- 如果需要使用预训练模型(特别是 Hugging Face) → 选择 DJL
- 如果需要完全 Java 实现的框架 → 选择 Deeplearning4j
对比 Smile
| 特性 | DJL | Smile |
|---|---|---|
| 深度学习 | ✅ 支持 | ❌ 不支持 |
| 传统 ML | ⚠️ 有限 | ✅ 支持 |
| Hugging Face | ✅ 支持 | ❌ 不支持 |
| 易用性 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
选择建议:
- 如果需要深度学习 → 选择 DJL
- 如果只需要传统机器学习 → 选择 Smile
第十二部分:总结
DJL 的核心优势
- 框架无关:一套代码支持所有框架,灵活切换
- 统一 API:学习成本低,一套 API 掌握所有框架
- Hugging Face 集成:可以直接加载模型,无需转换
- 易于使用:Java 风格,易于集成到现有项目
- 生产就绪:企业级支持,稳定可靠
- 高性能:直接调用底层框架,性能优秀
- 跨平台:支持多种操作系统和硬件
适用场景
适合使用 DJL:
- ✅ Java 项目需要深度学习功能
- ✅ 需要加载 Hugging Face 模型
- ✅ 需要支持多种框架
- ✅ 需要统一的 API
- ✅ 企业级应用
- ✅ 微服务架构中的模型服务
不适合使用 DJL:
- ❌ 只需要传统机器学习(使用 Smile)
- ❌ 只需要 ONNX 模型推理(使用 ONNX Runtime)
- ❌ Python 项目(使用原生框架更合适)
学习路径建议
-
基础阶段:
- 了解 DJL 的核心概念(Engine、Model、Translator、Predictor)
- 完成简单的图像分类示例
- 理解资源管理(try-with-resources)
-
进阶阶段:
- 学习自定义转换器
- 掌握 Hugging Face 模型加载
- 了解 GPU 支持和性能优化
-
高级阶段:
- 自定义模型训练
- 生产环境部署
- 性能调优和故障排查
类比总结
理解 DJL,就像理解为什么使用统一接口:
- 框架无关:就像 USB-C 接口,可以连接任何设备
- 统一 API:就像标准插座,任何电器都能用
- 易于使用:就像通用遥控器,一个控制所有
- Hugging Face 集成:就像应用商店,直接下载使用
- 生产就绪:就像企业级软件,稳定可靠
掌握 DJL,你就能在 Java 生态系统中轻松使用深度学习!