AI: Android 运行ONNX模型
概述
ONNX(Open Neural Network Exchange)模型, Android ONNX Runtime 是微软开源的跨平台推理引擎 ONNX Runtime 在 Android 平台上的应用版本,主要用于在 Android 设备上高效运行机器学习模型。
实现方法 (参考信息来源Grok)
1. 使用 ONNX Runtime
ONNX Runtime 是由微软开发的高性能推理引擎,支持在 Android 平台上运行 ONNX 模型。它提供了高效的优化和跨平台支持。
2. 使用 TensorFlow Lite(转换 ONNX 模型)
TensorFlow Lite 是 Android 上常用的轻量级深度学习框架。虽然它原生不支持 ONNX 模型,但可以通过转换工具将 ONNX 模型转换为 TFLite 格式。
3. 使用 PyTorch Mobile
如果 ONNX 模型是从 PyTorch 导出的,可以考虑直接使用 PyTorch Mobile 运行模型,绕过 ONNX 格式(或在必要时转换)。
4. 使用 MNN(Mobile Neural Network)
MNN 是阿里巴巴开发的轻量级推理框架,支持 ONNX 模型,适用于 Android 平台。
5. 使用 NCNN
NCNN 是腾讯优图开发的移动端推理框架,也支持 ONNX 模型。
比较与建议
| 方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| ONNX Runtime | 高性能、硬件加速、跨平台 | 需要学习 API | 通用、高性能推理 |
| TensorFlow Lite | 移动端优化、广泛支持 | 模型转换复杂 | 轻量级、资源受限设备 |
| PyTorch Mobile | 适合 PyTorch 模型、优化良好 | 不直接支持 ONNX | PyTorch 模型直接部署 |
| MNN | 轻量级、多格式支持 | 社区较小 | 资源受限设备、跨格式支持 |
| NCNN | 高性能、低内存占用 | 转换复杂、C++ 接口 | 高性能、低资源需求场景 |
建议:
- 如果追求简单性和高性能,ONNX Runtime 是首选,适合大多数场景。
- 如果模型复杂且需要移动端优化,考虑将 ONNX 转换为 TFLite 或 MNN。
- 如果模型来自 PyTorch,PyTorch Mobile 是更直接的选择。
- 对于极致性能和低资源占用,NCNN 是不错的选择,但需要更多开发工作。
注意事项
- 模型优化:运行前可使用 ONNX 优化工具(如
onnx-simplifier)简化模型,减少计算量。 - 硬件加速:根据设备支持,选择合适的硬件加速选项(如 NNAPI、GPU)。
- 兼容性测试:不同框架对 ONNX 算子的支持程度不同,需测试模型兼容性。
- 安全性:确保模型文件存储在安全位置,避免泄露。
尝试ONNX Runtime
实现步骤:
- 预处理输入:根据模型输入要求,将数据(例如图像或张量)转换为
OnnxTensor格式。 - 后处理输出:解析输出结果,转换为应用需要的格式。
优化:ONNX Runtime 支持硬件加速(如 NNAPI),可以在 SessionOptions 中启用:
SessionOptions options =newSessionOptions(); options.addNnapi();执行推理:使用 session.run() 方法运行模型,获取输出:
OnnxTensor inputTensor =OnnxTensor.createTensor(env, inputData);Map<String,OnnxTensor> inputs =newHashMap<>(); inputs.put("input_name", inputTensor);OrtSession.Result outputs = session.run(inputs);加载模型:将训练好的 ONNX 模型文件(例如 model.onnx)放入 Android 项目的 assets 目录或存储中,并通过 ONNX Runtime 加载:
importai.onnxruntime.OnnxTensor;importai.onnxruntime.OrtEnvironment;importai.onnxruntime.OrtSession;OrtEnvironment env =OrtEnvironment.getEnvironment();OrtSession session = env.createSession(modelPath,newOrtSession.SessionOptions());引入依赖:在 Android 项目的 build.gradle 文件中添加 ONNX Runtime 的依赖。例如:
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.18.0' 解析ONNX
在NETRON上打开ONNX文件, 可以看到如下信息:
对应在代码中获取的结果如下:
inputNames: x, h, c, inputNode:x: [1, 512], FLOAT inputNode:h: [2, 1, 64], FLOAT inputNode:c: [2, 1, 64], FLOAT outputNames: prob, new_h, new_c, outputNode: prob: [1, 1], FLOAT outputNode: new_h: [2, 1, 64], FLOAT outputNode: new_c: [2, 1, 64], FLOAT 在 ONNX Runtime 中,通过 session.getInputNames() 和 session.getOutputNames() 获取的输入和输出名称是 Set 类型,表示模型可能具有多个输入节点和多个输出节点。
模型支持多个输入节点(Multiple Input Nodes)
- 含义:ONNX 模型可以定义多个输入节点,每个节点有唯一的名称、形状和数据类型。
session.getInputNames()返回所有输入节点的名称集合。 - 应用场景:
- 多模态模型:例如,一个模型同时接受图像和文本作为输入。输入名称可能是
["image_input", "text_input"],分别对应图像张量(如[1, 3, 224, 224],FLOAT)和文本张量(如[1, 128],INT32)。 - 多分支网络:某些网络(如双塔模型)需要不同类型的输入数据(如用户特征和物品特征)。
- 控制输入:模型可能需要额外的输入(如超参数、权重调整张量)来控制推理行为。
- 多模态模型:例如,一个模型同时接受图像和文本作为输入。输入名称可能是
- 推理时:
- 因此,输入集合表示模型在一次推理中需要多种数据同时送入算法,而不是“支持多种输入类型”。
每次推理需要为所有输入节点提供数据,存储在 Map<String, OnnxTensor> 中。例如:
val inputs =mapOf("image_input"to imageTensor,"text_input"to textTensor )val result = session.run(inputs)输出集合:
- 多输出节点:类似输入,输出集合 (
session.getOutputNames()) 表示模型可能产生多个输出节点。例如:- 目标检测模型可能输出
["boxes", "scores", "labels"],分别表示边界框坐标、置信度分数和类别标签。 - 多任务学习模型可能输出分类和回归结果。
- 目标检测模型可能输出
推理时:OrtSession.Result 包含所有输出节点的张量,键是输出名称,值是 OnnxTensor。开发者可以选择处理全部或部分输出:
val outputNames = session.outputNames // 例如 ["boxes", "scores"] session.use{ it.run(inputs).use{ result ->val boxes = result.get("boxes")as OnnxTensor val scores = result.get("scores")as OnnxTensor // 处理 boxes 和 scores}}示例解读
基于 Android 平台的 ONNX 运行时的基本对象检测示例应用程序,支持 Ort-Extensions 进行预处理/后处理。该演示应用程序完成了从给定图像中检测对象的任务。此处使用的模型来自 Yolov8 扩展版本,并支持预处理/后处理。
该模型 (Yolov8n) 可以直接输入图像字节,并输出带有边界框的检测到的对象。
完整的示例代码可以参考:Object Detection Android sample
关键文件目录:
│ ├── main │ │ ├── AndroidManifest.xml │ │ ├── assets //测试图片 │ │ │ ├── test_object_detection_0.jpg │ │ │ └── test_object_detection_1.jpg │ │ ├── java │ │ │ └── ai │ │ │ └── onnxruntime │ │ │ └── example │ │ │ └── objectdetection │ │ │ ├── MainActivity.kt //主界面 │ │ │ └── ObjectDetector.kt //关键调用模型实现 │ │ └── res │ │ ├── drawable │ │ ├── raw │ │ │ ├── classes.txt //分类标签 │ │ │ └── yolov8n_with_pre_post_processing.onnx //模型文件 │ │ ├── values │ │ │ ├── colors.xml │ │ │ ├── ids.xml │ │ │ ├── strings.xml │ │ │ └── themes.xml │ │ └── xml │ │ ├── backup_rules.xml │ │ └── data_extraction_rules.xml MainActivity.kt
package ai.onnxruntime.example.objectdetection import ai.onnxruntime.*import ai.onnxruntime.extensions.OrtxPackage import android.annotation.SuppressLint import android.graphics.Bitmap import android.graphics.BitmapFactory import android.graphics.Canvas import android.graphics.Color import android.graphics.Paint import android.graphics.PorterDuff import android.graphics.PorterDuffXfermode import android.os.Bundle import android.util.Log import android.widget.Button import android.widget.ImageView import android.widget.Toast import androidx.activity.*import androidx.appcompat.app.AppCompatActivity import kotlinx.android.synthetic.main.activity_main.*import kotlinx.coroutines.*import java.io.InputStream import java.util.*class MainActivity :AppCompatActivity(){privatevar ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()privatelateinitvar ortSession: OrtSession privatelateinitvar inputImage: ImageView privatelateinitvar outputImage: ImageView privatelateinitvar objectDetectionButton: Button privatevar imageid =0;privatelateinitvar classes:List<String>@SuppressLint("UseCompatLoadingForDrawables")overridefunonCreate(savedInstanceState: Bundle?){super.onCreate(savedInstanceState)setContentView(R.layout.activity_main) inputImage =findViewById(R.id.imageView1) outputImage =findViewById(R.id.imageView2) objectDetectionButton =findViewById(R.id.object_detection_button) inputImage.setImageBitmap( BitmapFactory.decodeStream(readInputImage())); imageid =0 classes =readClasses();// Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators.// Note: These are used to decode the input image into the format the original model requires,// and to encode the model output into png formatval sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions() sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())//从raw中读取模型文件进行初始化 ortSession = ortEnv.createSession(readModel(), sessionOptions) objectDetectionButton.setOnClickListener{try{//启动算法检测performObjectDetection(ortSession) Toast.makeText(baseContext,"ObjectDetection performed!", Toast.LENGTH_SHORT).show()}catch(e: Exception){ Log.e(TAG,"Exception caught when perform ObjectDetection", e) Toast.makeText(baseContext,"Failed to perform ObjectDetection", Toast.LENGTH_SHORT).show()}}}overridefunonDestroy(){super.onDestroy() ortEnv.close() ortSession.close()}privatefunupdateUI(result: Result){val mutableBitmap: Bitmap = result.outputBitmap.copy(Bitmap.Config.ARGB_8888,true)val canvas =Canvas(mutableBitmap)val paint =Paint() paint.color = Color.WHITE // Text Color paint.textSize =28f// Text Size paint.xfermode =PorterDuffXfermode(PorterDuff.Mode.SRC_OVER)// Text Overlapping Pattern canvas.drawBitmap(mutableBitmap,0.0f,0.0f, paint)var boxit = result.outputBox.iterator()while(boxit.hasNext()){var box_info = boxit.next() canvas.drawText("%s:%.2f".format(classes[box_info[5].toInt()],box_info[4]), box_info[0]-box_info[2]/2, box_info[1]-box_info[3]/2, paint)} outputImage.setImageBitmap(mutableBitmap)}privatefunreadModel(): ByteArray {val modelID = R.raw.yolov8n_with_pre_post_processing return resources.openRawResource(modelID).readBytes()}privatefunreadClasses(): List<String>{return resources.openRawResource(R.raw.classes).bufferedReader().readLines()}privatefunreadInputImage(): InputStream { imageid = imageid.xor(1)return assets.open("test_object_detection_${imageid}.jpg")}//调用算法并读取解析结果, 最后更新UIprivatefunperformObjectDetection(ortSession: OrtSession){var objDetector =ObjectDetector()var imagestream =readInputImage() inputImage.setImageBitmap( BitmapFactory.decodeStream(imagestream)); imagestream.reset()var result = objDetector.detect(imagestream, ortEnv, ortSession)updateUI(result);}companionobject{constval TAG ="ORTObjectDetection"}}ObjectDetector.kt 调用ONXX模型.
package ai.onnxruntime.example.objectdetection import ai.onnxruntime.OnnxJavaType import ai.onnxruntime.OrtSession import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtEnvironment import android.graphics.Bitmap import android.graphics.BitmapFactory import java.io.InputStream import java.nio.ByteBuffer import java.util.*internaldataclassResult(var outputBitmap: Bitmap,var outputBox: Array<FloatArray>){}internalclassObjectDetector(){fundetect(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {// Step 1: convert image into byte array (raw image bytes)val rawImageBytes = inputStream.readBytes()// Step 2: get the shape of the byte array and make ort tensorval shape =longArrayOf(rawImageBytes.size.toLong())val inputTensor = OnnxTensor.createTensor( ortEnv, ByteBuffer.wrap(rawImageBytes), shape, OnnxJavaType.UINT8 ) inputTensor.use{// Step 3: call ort inferenceSession runval output = ortSession.run(Collections.singletonMap("image", inputTensor),setOf("image_out","scaled_box_out_next"))// Step 4: output analysis output.use{val rawOutput =(output?.get(0)?.value)as ByteArray val boxOutput =(output?.get(1)?.value)as Array<FloatArray>val outputImageBitmap =byteArrayToBitmap(rawOutput)// Step 5: set output resultvar result =Result(outputImageBitmap,boxOutput)return result }}}privatefunbyteArrayToBitmap(data: ByteArray): Bitmap {return BitmapFactory.decodeByteArray(data,0,data.size)}}执行效果:


扩展: ASR, 本地语音听写的实现(SherpaOnnxVadAsr)
sherpa-onnx
Android build
Android build 2
sherpa-onnx Android 源码和测试apk
asr-models
步骤:
- 下载源码
- 配置 SDK and NDK
- 调用 build-android-arm64-v8a.sh 进行编译, 然而编译失败 转而下载已发布的aardownload sherpa-onnx.aar
- 下载模型文件 modelsherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
SherpaOnnxVadAsr最终源码结构如下:
├── main │ ├── AndroidManifest.xml │ ├── assets │ │ ├── sherpa-onnx-paraformer-zh-2023-09-14 │ │ │ ├── model.int8.onnx │ │ │ └── tokens.txt │ │ └── silero_vad.onnx │ ├── java │ │ └── com │ │ └── k2fsa │ │ └── sherpa │ │ └── onnx │ │ ├── FeatureConfig.kt ->../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt │ │ ├── HomophoneReplacerConfig.kt ->../../../../../../../../../../sherpa-onnx/kotlin-api/HomophoneReplacerConfig.kt │ │ ├── MainActivity.kt │ │ ├── OfflineRecognizer.kt ->../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt │ │ ├── OfflineStream.kt ->../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt │ │ └── Vad.kt ->../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt │ ├── jniLibs │ │ ├── arm64-v8a │ │ │ ├── libonnxruntime4j_jni.so │ │ │ ├── libonnxruntime.so │ │ │ ├── libsherpa-onnx-c-api.so │ │ │ ├── libsherpa-onnx-cxx-api.so │ │ │ └── libsherpa-onnx-jni.so │ │ ├── armeabi-v7a │ │ │ ├── libonnxruntime4j_jni.so │ │ │ ├── libonnxruntime.so │ │ │ ├── libsherpa-onnx-c-api.so │ │ │ ├── libsherpa-onnx-cxx-api.so │ │ │ └── libsherpa-onnx-jni.so │ │ ├── x86 │ │ │ ├── libonnxruntime4j_jni.so │ │ │ ├── libonnxruntime.so │ │ │ ├── libsherpa-onnx-c-api.so │ │ │ ├── libsherpa-onnx-cxx-api.so │ │ │ └── libsherpa-onnx-jni.so │ │ └── x86_64 │ │ ├── libonnxruntime4j_jni.so │ │ ├── libonnxruntime.so │ │ ├── libsherpa-onnx-c-api.so │ │ ├── libsherpa-onnx-cxx-api.so │ │ └── libsherpa-onnx-jni.so │ └── res │ ├── drawable │ │ └── ic_launcher_background.xml │ ├── drawable-v24 │ │ └── ic_launcher_foreground.xml │ ├── layout │ │ └── activity_main.xml │ ├── mipmap-anydpi-v26 │ │ ├── ic_launcher_round.xml │ │ └── ic_launcher.xml │ ├── mipmap-hdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-mdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-xhdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-xxhdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-xxxhdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── values │ │ ├── colors.xml │ │ ├── strings.xml │ │ └── themes.xml │ ├── values-night │ │ └── themes.xml │ └── xml │ ├── backup_rules.xml │ └── data_extraction_rules.xml ## 参考 1. [android onnx](https://blog.51cto.com/u_16213465/13067353) 2. [Get started with ONNX Runtime Mobile](https://onnxruntime.ai/docs/get-started/with-mobile.html) 3. [ONNX Runtime](https://github.com/microsoft/onnxruntime) 4. [Object Detection Android sample](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/mobile/examples/object_detection/android)