您当前的位置: 首页 >  tensorflow

寒冰屋

暂无认证

  • 0浏览

    0关注

    2286博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

在Android上加载具有AI危害检测的TensorFlow模型

寒冰屋 发布时间:2021-01-24 21:46:40 ,浏览量:0

目录

格式化模型数据

测试模型

  • 下载源-53.8 MB

在本系列的上一篇文章中,我们创建了一个项目,该项目将用于驾驶员的实时危害检测,并准备了一个用于TensorFlow Lite的检测模型。在这里,我们将继续加载模型并为图像处理做准备。

要将模型添加到项目中,请在src/main中创建一个名为asset的新文件夹。将TensorFlow Lite模型和包含标签的文本文件复制到src/main/assets,使其成为项目的一部分。

要使用该模型,我们必须编写代码以加载该模型并通过它传递数据。检测代码将放置在两个用户界面可以共享的类中,以便可以在静态图像(用于测试)和实时视频流上使用相同的代码。

格式化模型数据

在开始为此编写代码之前,我们需要知道模型如何期望其输入数据被结构化。数据作为多维数组传入和传出。这也称为数据的“形状”。通常,当您找到模型时,将记录此信息。

您也可以使用Netron工具检查数据。使用此工具打开模型时,将显示组成网络的节点。单击输入节点(显示在图的顶部)将显示输入数据(在这种情况下为图像)的信息格式以及网络的输出。在这种情况下,我们看到输入数据是一个32位浮点数的数组。阵列的尺寸为1x416x416x3。这意味着网络将一次以416 x 416像素的速度接收一个包含红色、绿色和蓝色分量的图像。如果要为此项目使用其他模型,则需要检查模型的输入和输出,并相应地调整代码。解释结果时,我们将更详细地检查输出数据。

将新类添加到名为Detector的项目中。用于管理受训网络的所有代码都将添加到此类中。构建该类时,它将接受图像并以易于使用的格式提供结果。我们应该在类中添加一些常量和字段以开始使用它。这些字段包括一个TensorFlow Interpreter对象(包含受过训练的网络),该模型可识别的对象类别列表以及应用程序上下文。

class Detector {
   val TF_MODEL_NAME = "yolov4.tflite"
   val IMAGE_WIDTH = 416
   val IMAGE_HEIGHT = 416
   val TAG = "Detector"
   val useGpuDelegate = false;
   val useNNAPI=true;
   val context: Context;
   lateinit var tfLiteInterpreter:Interpreter
   var labelList = Vector()

   //These output values are structured to match the output of the trained model being used
   var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
   var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
   var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
   var outputBuffers: HashMap? = null
}

此类的构造函数将创建输出缓冲区,加载网络模型,并从资产文件夹加载对象类的名称。

class Detector {
   val TF_MODEL_NAME = "yolov4.tflite"
   val IMAGE_WIDTH = 416
   val IMAGE_HEIGHT = 416
   val TAG = "Detector"
   val useGpuDelegate = false;
   val useNNAPI=true;
   val context: Context;
   lateinit var tfLiteInterpreter:Interpreter
   var labelList = Vector()

   //These output values are structured to match the output of the trained model being used
   var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
   var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
   var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
   var outputBuffers: HashMap? = null
}
测试模型

只需几行代码即可执行网络模型。当将图像提供给Detector类时,将调整其大小以匹配网络的要求。Bitmap图像中的数据被编码为字节。该值必须转换为32位浮点值。TensorFlow Lite库包含使此类通用转换变得容易的功能。该TensorImage类型还具有一种方便的方法,允许将其用作需要输入缓冲区的方法的缓冲区。

public fun processImage(sourceImage: Bitmap) {
   val imageProcessor = ImageProcessor.Builder()
           .add(ResizeOp(IMAGE_HEIGHT, IMAGE_WIDTH, ResizeOp.ResizeMethod.BILINEAR))
           .build()
   var tImage = TensorImage(DataType.FLOAT32)
   tImage.load(sourceImage)
   tImage = imageProcessor.process(tImage)
   tfLiteInterpreter.runForMultipleInputsOutputs(arrayOf(tImage.buffer), outputBuffers!!)
}

要对此进行测试,请向项目中添加新的布局。布局将具有一个简单的界面,以允许从设备中选择图像。所选图像将由检测器处理。



   
   

该活动的代码将打开系统映像选择器。选择图像并将其传递回应用程序后,它将图像传递给检测器。

public override fun onActivityResult(reqCode: Int, resultCode: Int, data: Intent?) {
   super.onActivityResult(reqCode, resultCode, data)
   if (resultCode == RESULT_OK) {
       if (reqCode == SELECT_PICTURE) {
           val selectedUri = data!!.data
           val fileString = selectedUri!!.path
           selected_image_view!!.setImageURI(selectedUri)
           var sourceBitmap: Bitmap? = null
           try {
               sourceBitmap =
                   MediaStore.Images.Media.getBitmap(this.contentResolver, selectedUri)
               RunDetector(sourceBitmap)
           } catch (e: IOException) {
               e.printStackTrace()
           }
       }
   }
}

fun RunDetector(bitmap: Bitmap?) {
   if (detector == null) detector = Detector(this)
   detector!!.processImage(bitmap!!)
}

 

UI布局的结果

现在我们可以选择一个图像,检测器将处理该图像,识别其中的物体。但是结果是什么意思呢?我们如何使用这些结果来警告用户危险?在本系列的下一篇文章中,我们将解释结果并将相关信息提供给用户。

https://www.codeproject.com/Articles/5291389/Loading-a-TensorFlow-Model-with-AI-Hazard-Detectio

关注
打赏
1665926880
查看更多评论
立即登录/注册

微信扫码登录

0.0435s