Skip to content

Commit 770791a

Browse files
Merge pull request #22 from smartscanapp/update/face
Add face detector
2 parents 1c4dd04 + 25e4a41 commit 770791a

File tree

4 files changed

+319
-0
lines changed

4 files changed

+319
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.fpf.smartscansdk.core.embeddings
2+
3+
/**
 * Contract for a detector that consumes inputs of type [T].
 *
 * NOTE(review): callers are presumably expected to call [initialize] before [detect];
 * the interface itself does not enforce an ordering.
 */
interface IDetectorProvider<T> {
    /** Releases whatever the implementation holds. The default implementation does nothing. */
    fun closeSession() {}

    /** Prepares the detector for use. */
    suspend fun initialize()

    /** Reports whether the detector is currently ready to be used. */
    fun isInitialized(): Boolean

    /**
     * Runs detection on [data].
     *
     * @return a pair of parallel lists. NOTE(review): presumably (confidence scores,
     * bounding boxes), one entry per detection — confirm against the implementations.
     */
    suspend fun detect(data: T): Pair<List<Float>, List<FloatArray>>
}

core/src/main/java/com/fpf/smartscansdk/core/media/ImageUtils.kt

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import android.content.Context
44
import android.graphics.*
55
import android.net.Uri
66
import androidx.core.graphics.scale
7+
import kotlin.math.max
8+
import kotlin.math.min
79

810
fun centerCrop(bitmap: Bitmap, imageSize: Int): Bitmap {
911
val cropX: Int
@@ -45,3 +47,75 @@ fun getBitmapFromUri(context: Context, uri: Uri, maxSize: Int): Bitmap {
4547
decoder.setTargetSize(w, h)
4648
}.copy(Bitmap.Config.ARGB_8888, true)
4749
}
50+
51+
/**
 * Cuts one sub-bitmap out of [bitmap] for every usable box in [boxes].
 *
 * Each box is read as `[x1, y1, x2, y2]`. Coordinates are truncated to integers and
 * clamped to the bitmap bounds; a box whose clamped width or height is not positive is
 * skipped, so the result can be shorter than [boxes].
 *
 * @param bitmap source image to crop from.
 * @param boxes boxes to crop, in the pixel coordinates of [bitmap].
 * @return the cropped regions, in the order of the boxes that produced them.
 */
fun cropFaces(bitmap: Bitmap, boxes: List<FloatArray>): List<Bitmap> =
    boxes.mapNotNull { box ->
        val left = box[0].toInt().coerceAtLeast(0)
        val top = box[1].toInt().coerceAtLeast(0)
        val right = box[2].toInt().coerceAtMost(bitmap.width)
        val bottom = box[3].toInt().coerceAtMost(bitmap.height)
        val cropWidth = right - left
        val cropHeight = bottom - top
        if (cropWidth <= 0 || cropHeight <= 0) {
            // Degenerate or fully out-of-bounds box: nothing to crop.
            null
        } else {
            Bitmap.createBitmap(bitmap, left, top, cropWidth, cropHeight)
        }
    }
67+
68+
69+
/**
 * Greedy non-maximum suppression.
 *
 * Boxes are visited from highest to lowest score. Each visited box that has not been
 * suppressed is kept, and every lower-ranked box whose IoU with it is strictly greater
 * than [iouThreshold] is suppressed.
 *
 * @param boxes candidate boxes, each read as `[x1, y1, x2, y2]`.
 * @param scores score of each box; `scores[i]` belongs to `boxes[i]`.
 * @param iouThreshold overlap above which a lower-scored box is discarded.
 * @return indices into [boxes] of the kept boxes, ordered by descending score.
 * @throws IllegalArgumentException if [boxes] and [scores] differ in length.
 */
fun nms(boxes: List<FloatArray>, scores: List<Float>, iouThreshold: Float): List<Int> {
    // A length mismatch would otherwise either silently ignore boxes or fail with an
    // index error deep inside the loop, so reject it up front.
    require(boxes.size == scores.size) {
        "boxes (${boxes.size}) and scores (${scores.size}) must have the same length"
    }
    if (boxes.isEmpty()) return emptyList()

    val order = scores.indices.sortedByDescending { scores[it] }
    // suppressed[rank] refers to the box at order[rank]; a mask avoids repeatedly
    // shifting a mutable list while producing the same greedy result.
    val suppressed = BooleanArray(order.size)
    val keep = mutableListOf<Int>()

    for (rank in order.indices) {
        if (suppressed[rank]) continue
        val current = order[rank]
        keep.add(current)
        val currentBox = boxes[current]

        for (later in rank + 1 until order.size) {
            if (!suppressed[later] && computeIoU(currentBox, boxes[order[later]]) > iouThreshold) {
                suppressed[later] = true
            }
        }
    }
    return keep
}


/**
 * Intersection-over-union of two boxes, each read as `[x1, y1, x2, y2]`.
 *
 * Negative extents are treated as zero, and the result is `0f` when the union area is
 * not positive.
 */
private fun computeIoU(boxA: FloatArray, boxB: FloatArray): Float {
    val overlapWidth = max(0f, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    val overlapHeight = max(0f, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    val intersectionArea = overlapWidth * overlapHeight

    val areaA = max(0f, boxA[2] - boxA[0]) * max(0f, boxA[3] - boxA[1])
    val areaB = max(0f, boxB[2] - boxB[0]) * max(0f, boxB[3] - boxB[1])
    val unionArea = areaA + areaB - intersectionArea

    return if (unionArea <= 0f) 0f else intersectionArea / unionArea
}
99+
100+
/**
 * Draws an outlined rectangle for every box onto a copy of [bitmap].
 *
 * The input bitmap is left untouched; drawing happens on a mutable ARGB_8888 copy.
 * Each box is read as `[x1, y1, x2, y2]`, truncated to integers, grown by [margin] on
 * every side and then clamped to the bounds of the copy.
 *
 * @param bitmap image to annotate.
 * @param boxes boxes to outline, in the pixel coordinates of [bitmap].
 * @param color stroke color of the outlines.
 * @param margin number of pixels by which each box is expanded before drawing.
 * @param strokeWidth width of the outline stroke.
 * @return the annotated copy.
 */
fun drawBoxes(bitmap: Bitmap, boxes: List<FloatArray>, color: Int, margin: Int = 0, strokeWidth: Float = 2f): Bitmap {
    val annotated = bitmap.copy(Bitmap.Config.ARGB_8888, true)
    val canvas = Canvas(annotated)

    val outline = Paint()
    outline.color = color
    outline.strokeWidth = strokeWidth
    outline.style = Paint.Style.STROKE

    boxes.forEach { box ->
        val left = (box[0].toInt() - margin).coerceAtLeast(0)
        val top = (box[1].toInt() - margin).coerceAtLeast(0)
        val right = (box[2].toInt() + margin).coerceAtMost(annotated.width)
        val bottom = (box[3].toInt() + margin).coerceAtMost(annotated.height)

        canvas.drawRect(left.toFloat(), top.toFloat(), right.toFloat(), bottom.toFloat(), outline)
    }

    return annotated
}
121+
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package com.fpf.smartscansdk.ml.models.providers.detectors.face
2+
3+
import android.content.Context
4+
import android.graphics.Bitmap
5+
import android.util.Log
6+
import androidx.core.graphics.scale
7+
import com.fpf.smartscansdk.core.embeddings.IDetectorProvider
8+
import com.fpf.smartscansdk.core.media.nms
9+
import com.fpf.smartscansdk.ml.data.FilePath
10+
import com.fpf.smartscansdk.ml.data.ModelSource
11+
import com.fpf.smartscansdk.ml.data.ResourceId
12+
import com.fpf.smartscansdk.ml.data.TensorData
13+
import com.fpf.smartscansdk.ml.models.FileOnnxLoader
14+
import com.fpf.smartscansdk.ml.models.OnnxModel
15+
import com.fpf.smartscansdk.ml.models.ResourceOnnxLoader
16+
import kotlinx.coroutines.Dispatchers
17+
import kotlinx.coroutines.withContext
18+
import java.nio.ByteBuffer
19+
import java.nio.ByteOrder
20+
import java.nio.FloatBuffer
21+
22+
23+
/**
 * Face detector backed by an ONNX model.
 *
 * The input bitmap is scaled to [IMAGE_SIZE_X] x [IMAGE_SIZE_Y], normalized and fed to
 * the model; detections scoring above [confThreshold] are converted to pixel coordinates
 * of the original bitmap and de-duplicated with non-maximum suppression.
 *
 * @param context used to resolve the model when [modelSource] is a [ResourceId].
 * @param modelSource where the ONNX model is loaded from.
 * @param confThreshold minimum face score (exclusive) for a detection to be considered.
 * @param nmsThreshold IoU threshold passed to [nms].
 */
class FaceDetector(
    context: Context,
    modelSource: ModelSource,
    private val confThreshold: Float = 0.5f,
    private val nmsThreshold: Float = 0.3f
) : IDetectorProvider<Bitmap> {
    private val model: OnnxModel = when (modelSource) {
        is FilePath -> OnnxModel(FileOnnxLoader(modelSource.path))
        is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
    }

    // Makes closeSession() idempotent.
    private var closed = false

    override suspend fun initialize() = model.loadModel()

    override fun isInitialized() = model.isLoaded()

    /**
     * Detects faces in [data].
     *
     * @return a pair of parallel lists `(scores, boxes)`; each box is
     * `[x1, y1, x2, y2]` in pixel coordinates of [data]. Both lists are empty when
     * nothing scores above the confidence threshold.
     * @throws IllegalStateException if the model is not initialized, exposes no inputs,
     * or returns fewer than two outputs.
     */
    override suspend fun detect(data: Bitmap): Pair<List<Float>, List<FloatArray>> = withContext(Dispatchers.Default) {
        // Same guard (and message) as the sibling embedding provider uses.
        check(isInitialized()) { "Model not initialized" }

        val startTime = System.currentTimeMillis()
        // NCHW: batch, channels, height, width.
        val inputShape = longArrayOf(DIM_BATCH_SIZE.toLong(), DIM_PIXEL_SIZE.toLong(), IMAGE_SIZE_Y.toLong(), IMAGE_SIZE_X.toLong())
        val imgData: FloatBuffer = preProcess(data)
        val inputName = model.getInputNames()?.firstOrNull() ?: throw IllegalStateException("Model inputs not available")
        val outputs = model.run(mapOf(inputName to TensorData.FloatBufferTensor(imgData, inputShape)))

        val outputList = outputs.values.toList()
        // Fail with a clear message instead of an index error when the model is not the expected one.
        check(outputList.size >= 2) { "Expected score and box outputs but model returned ${outputList.size}" }

        // Output 0 holds scores, output 1 holds boxes; index 0 selects the single batch entry.
        @Suppress("UNCHECKED_CAST")
        val scoresRaw = (outputList[0] as Array<Array<FloatArray>>)[0] // shape: [num_boxes, 2]
        @Suppress("UNCHECKED_CAST")
        val boxesRaw = (outputList[1] as Array<Array<FloatArray>>)[0] // shape: [num_boxes, 4]

        val imgWidth = data.width
        val imgHeight = data.height

        val boxesList = mutableListOf<FloatArray>()
        val scoresList = mutableListOf<Float>()
        for (i in scoresRaw.indices) {
            // Column 1 is read as the face score.
            val faceScore = scoresRaw[i][1]
            if (faceScore > confThreshold) {
                val box = boxesRaw[i]
                // Box values are normalized; convert to absolute pixel coordinates.
                boxesList.add(
                    floatArrayOf(
                        box[0] * imgWidth,
                        box[1] * imgHeight,
                        box[2] * imgWidth,
                        box[3] * imgHeight
                    )
                )
                scoresList.add(faceScore)
            }
        }

        val inferenceTime = System.currentTimeMillis() - startTime
        Log.d(TAG, "Detection Inference Time: $inferenceTime ms")

        // nms() returns an empty list for empty input, so no separate empty branch is needed.
        val keepIndices = nms(boxesList, scoresList, nmsThreshold)
        Pair(keepIndices.map { scoresList[it] }, keepIndices.map { boxesList[it] })
    }

    /**
     * Scales [bitmap] to the model input size and returns its pixels as a planar
     * (all R, then all G, then all B) float buffer, each channel value mapped through
     * `(v - 127) / 128`.
     */
    private fun preProcess(bitmap: Bitmap): FloatBuffer {
        val resizedBitmap = bitmap.scale(IMAGE_SIZE_X, IMAGE_SIZE_Y)
        val width = resizedBitmap.width
        val height = resizedBitmap.height
        val planeSize = width * height
        val intValues = IntArray(planeSize)
        resizedBitmap.getPixels(intValues, 0, width, 0, 0, width, height)
        // Scaling returns the source itself when it already has the target size, so only
        // release a bitmap that was created here; the caller's bitmap is never recycled.
        if (resizedBitmap !== bitmap) resizedBitmap.recycle()

        val floatBuffer = ByteBuffer.allocateDirect(DIM_PIXEL_SIZE * planeSize * 4)
            .order(ByteOrder.nativeOrder())
            .asFloatBuffer()

        // Absolute puts write straight into the direct buffer (no intermediate array) and
        // leave the buffer position at 0, ready for reading.
        for (index in 0 until planeSize) {
            val pixel = intValues[index]
            val r = ((pixel shr 16) and 0xFF).toFloat()
            val g = ((pixel shr 8) and 0xFF).toFloat()
            val b = (pixel and 0xFF).toFloat()

            floatBuffer.put(index, (r - 127f) / 128f)
            floatBuffer.put(planeSize + index, (g - 127f) / 128f)
            floatBuffer.put(2 * planeSize + index, (b - 127f) / 128f)
        }
        return floatBuffer
    }

    /** Closes the model if it is [AutoCloseable]; calls after the first are no-ops. */
    override fun closeSession() {
        if (closed) return
        closed = true
        (model as? AutoCloseable)?.close()
    }

    companion object {
        private const val TAG = "FaceDetector"
        const val DIM_BATCH_SIZE = 1
        const val DIM_PIXEL_SIZE = 3
        const val IMAGE_SIZE_X = 320
        const val IMAGE_SIZE_Y = 240
    }
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package com.fpf.smartscansdk.ml.models.providers.embeddings.inception
2+
3+
import android.content.Context
4+
import android.graphics.Bitmap
5+
import com.fpf.smartscansdk.core.embeddings.ImageEmbeddingProvider
6+
import com.fpf.smartscansdk.core.media.centerCrop
7+
import com.fpf.smartscansdk.core.processors.BatchProcessor
8+
import com.fpf.smartscansdk.ml.data.FilePath
9+
import com.fpf.smartscansdk.ml.data.ModelSource
10+
import com.fpf.smartscansdk.ml.data.ResourceId
11+
import com.fpf.smartscansdk.ml.data.TensorData
12+
import com.fpf.smartscansdk.ml.models.FileOnnxLoader
13+
import com.fpf.smartscansdk.ml.models.OnnxModel
14+
import com.fpf.smartscansdk.ml.models.ResourceOnnxLoader
15+
import kotlinx.coroutines.Dispatchers
16+
import kotlinx.coroutines.withContext
17+
import java.nio.FloatBuffer
18+
19+
/**
 * Image embedding provider backed by an ONNX model.
 *
 * Each bitmap is center-cropped to [IMAGE_SIZE_X], normalized per channel with
 * [MEAN] and [STD], and run through the model; the first row of the first model output
 * is returned as the embedding.
 *
 * @param context used to resolve resource-based models and handed (as application
 * context) to the batch processor.
 * @param modelSource where the ONNX model is loaded from.
 */
class InceptionResnetFaceEmbedder(
    private val context: Context,
    modelSource: ModelSource,
) : ImageEmbeddingProvider {
    private val model: OnnxModel = when (modelSource) {
        is FilePath -> OnnxModel(FileOnnxLoader(modelSource.path))
        is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
    }

    override val embeddingDim: Int = 512

    // Makes closeSession() idempotent.
    private var closed = false

    override suspend fun initialize() = model.loadModel()

    override fun isInitialized() = model.isLoaded()

    /**
     * Computes the embedding of a single bitmap.
     *
     * @throws IllegalStateException if the model is not initialized or exposes no inputs.
     */
    override suspend fun embed(data: Bitmap): FloatArray = withContext(Dispatchers.Default) {
        check(isInitialized()) { "Model not initialized" }

        val tensorValues = preProcess(data)
        val tensorShape = longArrayOf(
            DIM_BATCH_SIZE.toLong(),
            DIM_PIXEL_SIZE.toLong(),
            IMAGE_SIZE_X.toLong(),
            IMAGE_SIZE_Y.toLong()
        )
        val inputName = model.getInputNames()?.firstOrNull() ?: error("Model inputs not available")
        val result = model.run(mapOf(inputName to TensorData.FloatBufferTensor(tensorValues, tensorShape)))

        @Suppress("UNCHECKED_CAST")
        val rows = result.values.first() as Array<FloatArray>
        rows[0]
    }

    /**
     * Embeds every bitmap in [data] by delegating each item to [embed] through a
     * [BatchProcessor] and collecting the per-batch results.
     */
    override suspend fun embedBatch(data: List<Bitmap>): List<FloatArray> {
        val collected = mutableListOf<FloatArray>()

        val batchRunner = object : BatchProcessor<Bitmap, FloatArray>(context = context.applicationContext) {
            override suspend fun onProcess(context: Context, item: Bitmap): FloatArray = embed(item)

            override suspend fun onBatchComplete(context: Context, batch: List<FloatArray>) {
                collected.addAll(batch)
            }
        }
        batchRunner.run(data)

        return collected
    }

    /**
     * Center-crops [bitmap] and writes its pixels into a planar (R plane, G plane,
     * B plane) float buffer, mapping each channel value through
     * `(v / 255 - MEAN[c]) / STD[c]`.
     */
    private fun preProcess(bitmap: Bitmap): FloatBuffer {
        val cropped = centerCrop(bitmap, IMAGE_SIZE_X)
        val planeSize = IMAGE_SIZE_X * IMAGE_SIZE_Y
        val buffer = FloatBuffer.allocate(DIM_BATCH_SIZE * DIM_PIXEL_SIZE * planeSize)

        val pixels = IntArray(planeSize)
        cropped.getPixels(pixels, 0, cropped.width, 0, 0, cropped.width, cropped.height)

        // One pass over the flattened pixel grid; absolute puts fill the three channel planes.
        for (offset in 0 until planeSize) {
            val argb = pixels[offset]
            val red = (argb shr 16) and 0xFF
            val green = (argb shr 8) and 0xFF
            val blue = argb and 0xFF
            buffer.put(offset, (red / 255f - MEAN[0]) / STD[0])
            buffer.put(offset + planeSize, (green / 255f - MEAN[1]) / STD[1])
            buffer.put(offset + planeSize * 2, (blue / 255f - MEAN[2]) / STD[2])
        }

        buffer.rewind()
        return buffer
    }

    /** Closes the model if it is [AutoCloseable]; calls after the first are no-ops. */
    override fun closeSession() {
        if (closed) return
        closed = true
        (model as? AutoCloseable)?.close()
    }

    companion object {
        private const val TAG = "FaceEmbedder"
        const val DIM_BATCH_SIZE = 1
        const val DIM_PIXEL_SIZE = 3
        const val IMAGE_SIZE_X = 160
        const val IMAGE_SIZE_Y = 160
        val MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
        val STD = floatArrayOf(0.229f, 0.224f, 0.225f)
    }
}

0 commit comments

Comments
 (0)