基于 DJL 与 PaddlePaddle 的菜品图像识别实现
本文介绍如何使用 Deep Java Library (DJL) 结合百度飞桨 (PaddlePaddle) 引擎,在 Java 环境中构建一个菜品图像分类系统。通过封装推理逻辑、加载预训练模型以及处理 Softmax 归一化输出,实现对上传图像菜品的自动识别。
一、技术架构与原理
1. 核心组件
- Deep Java Library (DJL): 一个跨深度学习框架的抽象层,支持 PyTorch、TensorFlow、MXNet 和 PaddlePaddle 等后端。它简化了模型加载、推理和部署的流程。
- PaddlePaddle: 百度开源的深度学习平台,提供高性能的推理引擎。在本方案中,我们使用 DJL 的 PaddlePaddle 引擎来调用飞桨模型。
- ResNet50: 本示例使用的预训练模型架构,适用于图像分类任务,具有较高的准确率。
2. Softmax 归一化说明
Softmax 函数将神经网络的原始输出(Logits)转换为概率分布。在多分类问题中,它确保所有类别的概率之和为 1,且每个概率值在 (0, 1) 之间。
计算公式如下:
$$\sigma(z)j = \frac{e^{z_j}}{\sum{k=1}^{K} e^{z_k}}$$
其中 $z_j$ 是输入向量的第 $j$ 个元素,$K$ 是类别总数。
特点与应用:
- 概率输出:直接反映模型对各类别的置信度。
- 增强差异:指数映射拉大数值间的差距,使预测结果更明确。
- 广泛应用:常用于神经网络输出层,配合交叉熵损失函数进行多分类训练。
二、环境准备与依赖配置
1. 开发环境要求
- JDK 8 或更高版本
- Maven 3.6+
- 操作系统:Linux, macOS 或 Windows (需安装相应 C++ 运行时库)
2. Maven 依赖配置
在 pom.xml 中添加以下依赖,引入 DJL API、基础数据集、模型库以及 PaddlePaddle 引擎。
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.dish</groupId>
<artifactId>dish_identification</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<djl.version>0.17.0</djl.version>
</properties>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.paddlepaddle</groupId>
<artifactId>paddlepaddle-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.paddlepaddle</groupId>
<artifactId>paddlepaddle-model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.17.2</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
</dependencies>
</project>
三、核心代码实现
1. 定义模型推理工具类
创建 DishesClassification 类,封装模型加载、预测及后处理逻辑。该类负责管理 Predictor 生命周期并执行 Softmax 计算。
package com.dish.utils;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
public final class DishesClassification {
private static final Logger logger = LoggerFactory.getLogger(DishesClassification.class);
private DishesClassification() {}
public static Classifications predict(Image img) throws IOException, ModelException, TranslateException {
Classifications classifications = classifier(img);
List<Classifications.Classification> items = classifications.items();
double[] probArr = [items.size()];
;
(Classifications.Classification item : items) {
item.getProbability();
probArr[items.indexOf(item)] = prob;
(prob > max) max = prob;
}
;
( ; i < probArr.length; i++) {
probArr[i] = Math.exp(probArr[i] - max);
sum += probArr[i];
}
List<String> names = <>();
List<Double> probs = <>();
( ; i < items.size(); i++) {
Classifications. items.get(i);
names.add(item.getClassName());
probs.add(probArr[i] / sum);
}
(names, probs);
}
Classifications IOException, ModelException, TranslateException {
Criteria<Image, Classifications> criteria = Criteria.builder()
.optEngine()
.setTypes(Image.class, Classifications.class)
.optModelPath(Paths.get())
.optModelName()
.optTranslator( ())
.optProgress( ())
.build();
( ModelZoo.loadModel(criteria)) {
(Predictor<Image, Classifications> predictor = model.newPredictor()) {
predictor.predict(img);
}
}
}
}
2. 翻译器实现 (DishTranslator)
虽然示例中未展示完整代码,但实际运行需要实现 Translator 接口以处理图像预处理(如调整大小、归一化)。通常 DJL 的 ImageClassifierTranslator 可复用。
3. 测试入口
编写 main 方法加载本地图片并打印识别结果。
package com.dish;
import ai.djl.ModelException;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.translate.TranslateException;
import com.dish.utils.DishesClassification;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Paths;
public final class Main {
private static final Logger logger = LoggerFactory.getLogger(Main.class);
public static void main(String[] args) throws IOException, ModelException, TranslateException {
String imagePath = "src/main/resources/images/test_dish.jpg";
Image image = ImageFactory.getInstance().fromFile(imagePath);
Classifications classifications = DishesClassification.predict(image);
Classifications.Classification bestItem = classifications.best();
System.out.println("识别结果:" + bestItem.getClassName() + " | 概率:" + bestItem.getProbability());
logger.info(, classifications);
}
}
四、模型资源准备
为了运行上述代码,您需要下载对应的模型文件。
- 访问 DJL Model Zoo 或百度 AI Studio。
- 搜索菜品分类相关模型(如 ResNet50 训练的 dishes 模型)。
- 将下载的
.zip 模型文件解压至项目根目录下的 models 文件夹中。
- 确保文件名与代码中的
optModelName("inference") 一致。
五、常见问题排查
- 模型加载失败:检查
models/dishes.zip 路径是否正确,确保文件未损坏。
- 内存溢出:如果处理高分辨率图片,建议在
ImageFactory 中设置缩放比例,减少显存占用。
- 依赖冲突:DJL 不同引擎可能依赖不同的底层库,确保
pom.xml 中仅引入当前使用的引擎依赖。
- 日志级别:默认日志可能包含大量调试信息,可通过
log4j.properties 调整级别为 WARN 或 ERROR。
六、总结
本方案展示了如何在 Java 生态中利用 DJL 快速集成 PaddlePaddle 模型。通过模块化设计,将模型加载与业务逻辑分离,便于后续扩展其他分类任务。在实际生产环境中,建议将模型部署到服务器端,通过 REST API 提供服务,并结合容器化技术进行资源调度。
注:文中涉及的技术细节基于 DJL 0.17.0 版本,具体 API 可能随版本更新有所变化,请以官方文档为准。