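This post continues my blog from June 15. I originally planned to port the YOLO V3 model I trained myself to a phone, but several attempts failed: after converting my own model to a .pb file, creating the TensorFlow inference interface always failed. I suspect something went wrong when I saved the model, so I decided to first port an SSD model that the official demo can already run. Without further ado, here is a screenshot of the final result on my phone.

I mainly followed parts 2 and 3 of https://www.voidcn.com/article/p-rbnqjtim-brt.html. The steps are summarized here, and the code for each one is in the sections below.

Put the trained .pb file into app/src/main/assets of the Android project. If the assets directory doesn't exist, right-click main -> New -> Directory and enter assets.

Put the downloaded libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar into the libs folder: the .jar goes directly in libs, and the .so goes in the ABI subfolder libs/armeabi-v7a. I'll upload these dependency files to my resources page so they can be downloaded directly.

In app/build.gradle, add two settings to defaultConfig and add a sourceSets block, as in my screenshot of the finished file.

Create the Classifier interface. Creating an interface works the same way as creating a class, so I won't repeat it here.

Create the detector class. Its main jobs are: 1. create an interface to the model; 2. feed the image to be recognized into the model; 3. fetch the recognition results from the model and return the final result. Declaring that the class implements Classifier may show an error at first; just click the red light bulb at the error and Android Studio will add the interface methods for you.

Resize the image to the size the model expects with bitmapResize (bitmap is the image you want to feed to the model), then compute the scale ratio between the original image and the model input (scaleimageX and scaleimageY are floats).

Set up the canvas and paint, map the model's detections back to the original image, and draw the bounding boxes, class names and probabilities.

That's all it takes to call the model. I mainly read TensorFlow's Android demo and extracted the parts I needed, and in the end it worked.

It took me 6 days, doing nothing but this apart from eating and sleeping. I went from having no idea where to start to finally getting the model-calling feature I needed. The biggest takeaway is that I now know the steps for calling a deep learning model on Android, which means the second-to-last chapter of my graduation thesis is essentially done. Next I need to find the right way to save my model and convert it correctly to a .pb file, so that I can port my own stored-grain insect recognition model to the phone.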
2.1 Put the model to be ported into the assets folder
In this step I simply copied ssd_mobilenet_v1_android_export.pb and coco_labels_list.txt from TensorFlow's official demo into the assets folder.
2.2 Add the .so and .jar dependencies
2.3 Configure app/build.gradle (Module: app)
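defaultConfig {
    // ...keep the existing settings, then add:
    multiDexEnabled true
    ndk {
        abiFilters "armeabi-v7a" // only package native libraries for armeabi-v7a
    }
}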
android {
    // ...
    sourceSets {
        main {
            jniLibs.srcDirs = ['libs'] // look for .so files under app/libs/<abi>/
        }
    }
}
In dependencies, add the TensorFlow-built jar libandroid_tensorflow_inference_java.jar:
dependencies {
    implementation files('libs/libandroid_tensorflow_inference_java.jar')
}
3.1 Create the Classifier interface
This interface was copied straight from the official demo without any changes, so you can simply copy and paste it.

package com.example.mycamera;

import android.graphics.Bitmap;
import android.graphics.RectF;
import java.util.List;

public interface Classifier {
    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /**
         * Display name for the recognition.
         */
        private final String title;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private final Float confidence;

        /** Optional location within the source image for the location of the recognized object. */
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        public RectF getLocation() {
            return new RectF(location);
        }

        public void setLocation(RectF location) {
            this.location = location;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }
            if (title != null) {
                resultString += title + " ";
            }
            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }
            if (location != null) {
                resultString += location + " ";
            }
            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void enableStatLogging(final boolean debug);

    String getStatString();

    void close();
}
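3.2 Create the detector class that implements Classifier

// Inside: public class TFYoloV3Detector implements Classifier { ... }
// Imports needed: android.content.res.AssetManager, java.io.BufferedReader, java.io.IOException,
// java.io.InputStream, java.io.InputStreamReader, java.util.Vector, org.tensorflow.Graph,
// org.tensorflow.Operation, org.tensorflow.contrib.android.TensorFlowInferenceInterface

// Fields used by create() and recognizeImage() (types follow from how they are used).
// The output buffers must hold every detection the model returns; the official demo uses 100.
private static final int MAX_RESULTS = 100;
private Vector<String> labels = new Vector<String>();
private String inputName;
private int inputSize;
private int[] intValues;
private byte[] byteValues;
private float[] outputLocations;
private float[] outputScores;
private float[] outputClasses;
private float[] outputNumDetections;
private String[] outputNames;
private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;

// Static factory: loads the label file and the .pb model from assets and checks that the
// graph contains the expected input and output nodes.
public static Classifier create(
        final AssetManager assetManager,
        final String modelFilename,
        final String labelFilename,
        final int inputSize) throws IOException {
    final TFYoloV3Detector d = new TFYoloV3Detector();

    // Read the labels; labelFilename must start with "file:///android_asset/"
    String actualFilename = labelFilename.split("file:///android_asset/")[1];
    InputStream labelsInput = assetManager.open(actualFilename);
    BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
        d.labels.add(line); // line index = class id
    }
    br.close();

    // Load the frozen graph (.pb) from assets
    d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
    final Graph g = d.inferenceInterface.graph();

    // Check the input node
    d.inputName = "image_tensor";
    final Operation inputOp = g.operation(d.inputName);
    if (inputOp == null) {
        throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
    }
    d.inputSize = inputSize;

    // Check the output nodes
    final Operation outputOp1 = g.operation("detection_scores");
    if (outputOp1 == null) {
        throw new RuntimeException("Failed to find output Node 'detection_scores'");
    }
    final Operation outputOp2 = g.operation("detection_boxes");
    if (outputOp2 == null) {
        throw new RuntimeException("Failed to find output Node 'detection_boxes'");
    }
    final Operation outputOp3 = g.operation("detection_classes");
    if (outputOp3 == null) {
        throw new RuntimeException("Failed to find output Node 'detection_classes'");
    }

    // Pre-allocate buffers.
    d.outputNames = new String[] {"detection_boxes", "detection_scores",
            "detection_classes", "num_detections"};
    d.intValues = new int[d.inputSize * d.inputSize];
    d.byteValues = new byte[d.inputSize * d.inputSize * 3];
    d.outputScores = new float[MAX_RESULTS];
    d.outputLocations = new float[MAX_RESULTS * 4];
    d.outputClasses = new float[MAX_RESULTS];
    d.outputNumDetections = new float[1];
    return d;
}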
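The label-reading part of create() can be tried off the device. Below is a minimal plain-Java sketch (LabelLoader, stripAssetPrefix and readLabels are hypothetical names of mine, not part of the TensorFlow demo) that strips the file:///android_asset/ prefix and reads one label per line. Unlike split(...)[1], it returns the name unchanged when the prefix is missing instead of throwing. On Android the InputStream would come from assetManager.open(...); here an in-memory stream stands in for it. The sample labels assume that line 0 is a placeholder for the background class, which is how I understand the demo's coco_labels_list.txt, so the class ids from detection_classes index the list directly.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class LabelLoader {
    static final String ASSET_PREFIX = "file:///android_asset/";

    // "file:///android_asset/coco_labels_list.txt" -> "coco_labels_list.txt";
    // names without the prefix are returned unchanged instead of throwing.
    static String stripAssetPrefix(String filename) {
        return filename.startsWith(ASSET_PREFIX)
                ? filename.substring(ASSET_PREFIX.length())
                : filename;
    }

    // Reads one label per line; the line index is the class id.
    static List<String> readLabels(InputStream in) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical label file: a background placeholder, then two classes
        String fake = "???\nperson\nbicycle\n";
        List<String> labels = readLabels(
                new ByteArrayInputStream(fake.getBytes(StandardCharsets.UTF_8)));
        System.out.println(stripAssetPrefix("file:///android_asset/coco_labels_list.txt"));
        System.out.println(labels.get(1)); // label for class id 1
    }
}
```

Running main prints coco_labels_list.txt and then person.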
3.3 Call the model and return the recognition results
// Runs detection on one image. The bitmap must already be inputSize x inputSize
// (see bitmapResize in 4.1), otherwise its pixels won't fit into intValues.
// Imports needed: android.graphics.Bitmap, android.graphics.RectF, java.util.ArrayList, java.util.List
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
    // Copy the input data into TensorFlow: unpack the ARGB ints into R, G, B bytes
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    for (int i = 0; i < intValues.length; ++i) {
        byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);         // blue
        byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);  // green
        byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF); // red
    }

    // Feed the image to the model and run it
    inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
    inferenceInterface.run(outputNames, logStats);

    // Fetch the recognition results
    outputLocations = new float[MAX_RESULTS * 4];
    outputScores = new float[MAX_RESULTS];
    outputClasses = new float[MAX_RESULTS];
    outputNumDetections = new float[1];
    inferenceInterface.fetch(outputNames[0], outputLocations);
    inferenceInterface.fetch(outputNames[1], outputScores);
    inferenceInterface.fetch(outputNames[2], outputClasses);
    inferenceInterface.fetch(outputNames[3], outputNumDetections);

    // Boxes are normalized [ymin, xmin, ymax, xmax]; scale them to input-size pixels
    final int numDetections = Math.min((int) outputNumDetections[0], MAX_RESULTS);
    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    for (int i = 0; i < numDetections; ++i) {
        final RectF detection = new RectF(
                outputLocations[4 * i + 1] * inputSize,  // left
                outputLocations[4 * i] * inputSize,      // top
                outputLocations[4 * i + 3] * inputSize,  // right
                outputLocations[4 * i + 2] * inputSize); // bottom
        recognitions.add(new Recognition(
                "" + i, labels.get((int) outputClasses[i]), outputScores[i], detection));
    }
    return recognitions;
}
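The two array manipulations in recognizeImage are easy to get wrong, so here is a plain-Java sketch of them that runs without Android (SsdPostprocess, argbToRgbBytes and toPixelBox are hypothetical names). It assumes, as the code above does, that getPixels returns packed ARGB ints and that detection_boxes holds normalized [ymin, xmin, ymax, xmax] boxes, four floats per detection.

```java
import java.util.Arrays;

public class SsdPostprocess {
    // Splits packed ARGB ints into an interleaved R, G, B byte array,
    // dropping alpha: the layout fed to "image_tensor" above.
    static byte[] argbToRgbBytes(int[] argb) {
        byte[] rgb = new byte[argb.length * 3];
        for (int i = 0; i < argb.length; ++i) {
            rgb[i * 3] = (byte) ((argb[i] >> 16) & 0xFF);    // red
            rgb[i * 3 + 1] = (byte) ((argb[i] >> 8) & 0xFF); // green
            rgb[i * 3 + 2] = (byte) (argb[i] & 0xFF);        // blue
        }
        return rgb;
    }

    // Converts detection i from normalized [ymin, xmin, ymax, xmax] to pixel
    // {left, top, right, bottom} for a square model input of side inputSize.
    static float[] toPixelBox(float[] boxes, int i, int inputSize) {
        return new float[] {
            boxes[4 * i + 1] * inputSize, // left = xmin
            boxes[4 * i] * inputSize,     // top = ymin
            boxes[4 * i + 3] * inputSize, // right = xmax
            boxes[4 * i + 2] * inputSize  // bottom = ymax
        };
    }

    public static void main(String[] args) {
        int[] pixels = {0xFF102030}; // opaque pixel with R=0x10, G=0x20, B=0x30
        System.out.println(Arrays.toString(argbToRgbBytes(pixels))); // [16, 32, 48]
        float[] boxes = {0.25f, 0.5f, 0.75f, 1.0f};
        System.out.println(Arrays.toString(toPixelBox(boxes, 0, 416)));
        // [208.0, 104.0, 416.0, 312.0]
    }
}
```

Note the swap: the model puts y first, while RectF's constructor takes left (x) first.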
4.1 Preprocess the image before feeding it to the model
// Resize to the model input size (416*416 here); this must equal the inputSize passed to create()
Bitmap bitmapResized = bitmapResize(bitmap, YOLO_INPUT_SIZE, YOLO_INPUT_SIZE);
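// Scales the original image to the model's input size. bitmap is the original image;
// rx and ry are the model's input width and height. x and y are scaled independently,
// so the aspect ratio is not preserved.
public static Bitmap bitmapResize(Bitmap bitmap, int rx, int ry) {
    int height = bitmap.getHeight();
    int width = bitmap.getWidth();
    // Compute the scale factors
    float scaleWidth = ((float) rx) / width;
    float scaleHeight = ((float) ry) / height;
    Matrix matrix = new Matrix(); // android.graphics.Matrix
    matrix.postScale(scaleWidth, scaleHeight);
    // filter = true: filtered (smoothed) scaling
    return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
}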
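To show what the independent x/y scaling in bitmapResize does, here is a small plain-Java sketch on a raw pixel array (NearestResize is a hypothetical helper, not an Android API). It uses nearest-neighbour sampling for simplicity; Bitmap.createBitmap with filter set to true filters the result instead, so pixel values will differ, but the geometry is the same: each axis is scaled by its own factor and the aspect ratio is not preserved.

```java
import java.util.Arrays;

public class NearestResize {
    // Resizes a row-major pixel array from (w, h) to (rx, ry), scaling x and y
    // independently. Each output pixel copies the source pixel under its centre.
    static int[] resize(int[] src, int w, int h, int rx, int ry) {
        int[] dst = new int[rx * ry];
        for (int y = 0; y < ry; ++y) {
            int sy = Math.min(h - 1, (int) ((y + 0.5f) * h / ry));
            for (int x = 0; x < rx; ++x) {
                int sx = Math.min(w - 1, (int) ((x + 0.5f) * w / rx));
                dst[y * rx + x] = src[sy * w + sx];
            }
        }
        return dst;
    }

    public static void main(String[] args) {
        int[] src = {1, 2, 3, 4}; // 2x2 image
        int[] dst = resize(src, 2, 2, 4, 4);
        System.out.println(Arrays.toString(dst));
        // [1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4]
    }
}
```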
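// Ratio between the original image and the model input, per axis (scaleimageX and scaleimageY are float fields)
scaleimageX = (float) bitmap.getWidth() / bitmapResized.getWidth();   // x direction
scaleimageY = (float) bitmap.getHeight() / bitmapResized.getHeight(); // y direction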
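A worked example of the scale ratio and of mapping a box back to the original image, in plain Java (BoxMapping is a hypothetical name, and the 1248x832 photo size is made up for the example). With a 416x416 model input, scaleimageX = 1248 / 416 = 3 and scaleimageY = 832 / 416 = 2, so every x coordinate of a box is multiplied by 3 and every y coordinate by 2.

```java
import java.util.Arrays;

public class BoxMapping {
    // Ratio between original size and model input size along one axis,
    // the same computation as scaleimageX / scaleimageY.
    static float scale(int originalSize, int inputSize) {
        return (float) originalSize / inputSize;
    }

    // Maps {left, top, right, bottom} from model-input pixels to original-image pixels.
    static float[] toOriginal(float[] box, float sx, float sy) {
        return new float[] {box[0] * sx, box[1] * sy, box[2] * sx, box[3] * sy};
    }

    public static void main(String[] args) {
        float sx = scale(1248, 416); // 3.0
        float sy = scale(832, 416);  // 2.0
        float[] box = toOriginal(new float[] {208f, 104f, 416f, 312f}, sx, sy);
        System.out.println(Arrays.toString(box)); // [624.0, 208.0, 1248.0, 624.0]
    }
}
```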
4.2 Fetch the model's recognition results
final List<Classifier.Recognition> results = detector.recognizeImage(bitmapResized); // fetch the recognition results
4.3 Draw the recognition results on the original image
// Imports needed: android.graphics.Canvas, android.graphics.Color, android.graphics.Paint,
// android.graphics.RectF, java.util.LinkedList, java.util.List
croppedBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true); // mutable copy of the original image
final Canvas canvas = new Canvas(croppedBitmap);            // canvas that draws onto the copy
final Paint paint = new Paint();                            // paint for the bounding boxes
paint.setColor(Color.RED);
paint.setStyle(Paint.Style.STROKE);                         // outline only, no fill
paint.setStrokeWidth(5.0f);                                 // line width
final Paint paintText = new Paint();                        // paint for the text
paintText.setColor(Color.RED);
paintText.setTextSize(80);                                  // text size

final float minimumConfidence = MINIMUM_CONFIDENCE_YOLO;
final List<Classifier.Recognition> mappedRecognitions = new LinkedList<Classifier.Recognition>();
for (final Classifier.Recognition result : results) {
    // Only mark detections at or above the confidence threshold
    if (result.getConfidence() < minimumConfidence) {
        continue;
    }
    // Map the box back to the original image. getLocation() returns a copy,
    // so read it once and build the scaled rectangle from it.
    final RectF box = result.getLocation();
    final RectF location = new RectF(
            box.left * scaleimageX,
            box.top * scaleimageY,
            box.right * scaleimageX,
            box.bottom * scaleimageY);
    canvas.drawRect(location, paint); // draw the bounding box
    // Show the class name and probability above the box
    canvas.drawText(result.getTitle() + " " + result.getConfidence(),
            location.left, location.top - 10, paintText);
    result.setLocation(location);
    mappedRecognitions.add(result);
}
cameraPicture.setImageBitmap(croppedBitmap); // show the final image in the ImageView
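The drawing loop above prints the raw confidence float next to the class name. Below is a plain-Java sketch of the same threshold filtering with a percentage caption instead. Detection, filter and caption are hypothetical stand-ins (no Android types, so it runs anywhere), and the 0.6 threshold is made up because the value of MINIMUM_CONFIDENCE_YOLO isn't shown in this post.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class DetectionFilter {
    // Hypothetical stand-in for Classifier.Recognition without RectF.
    static final class Detection {
        final String title;
        final float confidence;

        Detection(String title, float confidence) {
            this.title = title;
            this.confidence = confidence;
        }
    }

    // Keeps detections at or above the threshold, like the drawing loop.
    static List<Detection> filter(List<Detection> all, float minimumConfidence) {
        List<Detection> kept = new ArrayList<>();
        for (Detection d : all) {
            if (d.confidence >= minimumConfidence) {
                kept.add(d);
            }
        }
        return kept;
    }

    // Caption such as "person 87.5%" instead of the raw float.
    static String caption(Detection d) {
        return String.format(Locale.US, "%s %.1f%%", d.title, d.confidence * 100.0f);
    }

    public static void main(String[] args) {
        List<Detection> all = new ArrayList<>();
        all.add(new Detection("person", 0.875f));
        all.add(new Detection("dog", 0.3f));
        for (Detection d : filter(all, 0.6f)) {
            System.out.println(caption(d)); // person 87.5%
        }
    }
}
```

On Android, a string built the same way could be passed to canvas.drawText in place of result.getTitle() + " " + result.getConfidence().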
