package com.fpf.smartscansdk.ml.models.providers.detectors.face

import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.core.graphics.scale
import com.fpf.smartscansdk.core.embeddings.IDetectorProvider
import com.fpf.smartscansdk.core.media.nms
import com.fpf.smartscansdk.ml.data.FilePath
import com.fpf.smartscansdk.ml.data.ModelSource
import com.fpf.smartscansdk.ml.data.ResourceId
import com.fpf.smartscansdk.ml.data.TensorData
import com.fpf.smartscansdk.ml.models.FileOnnxLoader
import com.fpf.smartscansdk.ml.models.OnnxModel
import com.fpf.smartscansdk.ml.models.ResourceOnnxLoader
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer

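/**
 * ONNX-based face detector.
 *
 * Loads a detection model from a file path or an Android resource, runs it on a
 * [Bitmap] resized to 320x240, and returns confidence scores paired with bounding
 * boxes in pixel coordinates of the original bitmap. Detections below
 * [confThreshold] are dropped, and overlapping boxes are suppressed via NMS using
 * [nmsThreshold].
 */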
class FaceDetector(
    context: Context,
    modelSource: ModelSource,
    private val confThreshold: Float = 0.5f,
    private val nmsThreshold: Float = 0.3f
) : IDetectorProvider<Bitmap> {
    private val model: OnnxModel = when (modelSource) {
        is FilePath -> OnnxModel(FileOnnxLoader(modelSource.path))
        is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
    }

    companion object {
        private const val TAG = "FaceDetector"
        const val DIM_BATCH_SIZE = 1
        const val DIM_PIXEL_SIZE = 3
        const val IMAGE_SIZE_X = 320
        const val IMAGE_SIZE_Y = 240
    }

    override suspend fun initialize() = model.loadModel()

    override fun isInitialized() = model.isLoaded()

    private var closed = false

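    /**
     * Runs face detection on [data] and returns detection scores paired with
     * bounding boxes as [x1, y1, x2, y2] in pixel coordinates of [data].
     */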
    override suspend fun detect(data: Bitmap): Pair<List<Float>, List<FloatArray>> = withContext(Dispatchers.Default) {
        val startTime = System.currentTimeMillis()
        val inputShape = longArrayOf(DIM_BATCH_SIZE.toLong(), DIM_PIXEL_SIZE.toLong(), IMAGE_SIZE_Y.toLong(), IMAGE_SIZE_X.toLong())
        val imgData: FloatBuffer = preProcess(data)
        val inputName = model.getInputNames()?.firstOrNull() ?: throw IllegalStateException("Model inputs not available")
        val outputs = model.run(mapOf(inputName to TensorData.FloatBufferTensor(imgData, inputShape)))

        val outputList = outputs.values.toList()
        @Suppress("UNCHECKED_CAST")
        val scoresRawFull = outputList[0] as Array<Array<FloatArray>>
        @Suppress("UNCHECKED_CAST")
        val boxesRawFull = outputList[1] as Array<Array<FloatArray>>

        // Extract the first element (batch dimension)
        val scoresRaw = scoresRawFull[0] // shape: [num_boxes, 2]
        val boxesRaw = boxesRawFull[0] // shape: [num_boxes, 4]

        val imgWidth = data.width
        val imgHeight = data.height

        val boxesList = mutableListOf<FloatArray>()
        val scoresList = mutableListOf<Float>()
        for (i in scoresRaw.indices) {
            val faceScore = scoresRaw[i][1]
            if (faceScore > confThreshold) {
                val box = boxesRaw[i]
                // Box values are normalized; convert to absolute pixel coordinates.
                val x1 = box[0] * imgWidth
                val y1 = box[1] * imgHeight
                val x2 = box[2] * imgWidth
                val y2 = box[3] * imgHeight
                boxesList.add(floatArrayOf(x1, y1, x2, y2))
                scoresList.add(faceScore)
            }
        }

        val inferenceTime = System.currentTimeMillis() - startTime
        Log.d(TAG, "Detection Inference Time: $inferenceTime ms")

        // Apply NMS if any detection exists.
        if (boxesList.isNotEmpty()) {
            val keepIndices = nms(boxesList, scoresList, nmsThreshold)
            val filteredBoxes = keepIndices.map { boxesList[it] }
            val filteredScores = keepIndices.map { scoresList[it] }
            return@withContext Pair(filteredScores, filteredBoxes)
        } else {
            return@withContext Pair(emptyList<Float>(), emptyList<FloatArray>())
        }
    }

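    /**
     * Converts [bitmap] to the model's expected input: a 320x240 image in
     * channel-first (CHW) order with each channel normalized as (value - 127) / 128.
     */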
    private fun preProcess(bitmap: Bitmap): FloatBuffer {
        val resizedBitmap = bitmap.scale(IMAGE_SIZE_X, IMAGE_SIZE_Y)
        val width = resizedBitmap.width
        val height = resizedBitmap.height
        val intValues = IntArray(width * height)
        resizedBitmap.getPixels(intValues, 0, width, 0, 0, width, height)

        val floatArray = FloatArray(DIM_PIXEL_SIZE * height * width)

        // Process each pixel and store it in channel-first order:
        // channel 0 occupies indices 0 until height*width, channel 1 the next block, and so on.
        for (i in 0 until height) {
            for (j in 0 until width) {
                val pixel = intValues[i * width + j]
                val r = ((pixel shr 16) and 0xFF).toFloat()
                val g = ((pixel shr 8) and 0xFF).toFloat()
                val b = (pixel and 0xFF).toFloat()

                // Normalize channels
                val normalizedR = (r - 127f) / 128f
                val normalizedG = (g - 127f) / 128f
                val normalizedB = (b - 127f) / 128f

                val index = i * width + j
                floatArray[index] = normalizedR
                floatArray[height * width + index] = normalizedG
                floatArray[2 * height * width + index] = normalizedB
            }
        }

        val byteBuffer = ByteBuffer.allocateDirect(floatArray.size * 4).order(ByteOrder.nativeOrder())
        val floatBuffer = byteBuffer.asFloatBuffer()
        floatBuffer.put(floatArray)
        floatBuffer.position(0)
        return floatBuffer
    }

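    /**
     * Releases the underlying model if it holds closeable resources. Safe to call
     * more than once; subsequent calls are no-ops.
     */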
    override fun closeSession() {
        if (closed) return
        closed = true
        (model as? AutoCloseable)?.close()
    }
}
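
// Usage sketch (not part of the original commit): a minimal example of driving the
// detector from a coroutine. It assumes FilePath can be constructed from a path
// string; the model path below is hypothetical.
//
//     suspend fun detectFaces(context: Context, bitmap: Bitmap) {
//         val detector = FaceDetector(context, FilePath("/path/to/face_model.onnx"))
//         detector.initialize()
//         val (scores, boxes) = detector.detect(bitmap)
//         Log.d("Example", "Found ${boxes.size} faces with scores $scores")
//         detector.closeSession()
//     }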