Skip to content

Commit 7155e45

Browse files
Merge pull request #1 from dev-diaries41/bugs/indexing
Fix bug causing full reindexing and duplicates
2 parents b78719c + da62aac commit 7155e45

File tree

11 files changed

+87
-54
lines changed

11 files changed

+87
-54
lines changed

core/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ val gitVersion: String by lazy {
9999
publishing {
100100
publications {
101101
register<MavenPublication>("release") {
102-
groupId = "com.github.dev-diaries41"
102+
groupId = "com.github.dev-diaries41.smartscan-sdk"
103103
artifactId = "smartscan-${project.name}"
104104
version = gitVersion
105105

core/src/androidTest/kotlin/com/fpf/smartscansdk/core/ml/embeddings/clip/ClipImageEmbedderTest.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.fpf.smartscansdk.core.ml.embeddings.clip
22

33
import ai.onnxruntime.OnnxTensor
4+
import ai.onnxruntime.OnnxTensorLike
45
import ai.onnxruntime.OrtEnvironment
56
import android.content.Context
67
import android.content.res.Resources
@@ -10,6 +11,7 @@ import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipConfig.IMAGE_SIZE_X
1011
import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipConfig.IMAGE_SIZE_Y
1112
import com.fpf.smartscansdk.core.ml.models.OnnxModel
1213
import com.fpf.smartscansdk.core.ml.models.ResourceId
14+
import com.fpf.smartscansdk.core.ml.models.TensorData
1315
import io.mockk.*
1416
import kotlinx.coroutines.runBlocking
1517
import org.junit.After
@@ -77,7 +79,7 @@ class ClipImageEmbedderInstrumentedTest {
7779

7880
// prepare fake output: Array<FloatArray> with length 512
7981
val raw = Array(1) { FloatArray(embedder.embeddingDim) { 1.0f } }
80-
every { mockModel.run(any<Map<String, ai.onnxruntime.OnnxTensorLike>>()) } returns mapOf("out" to raw)
82+
every { mockModel.run(any<Map<String, TensorData>>()) } returns mapOf("out" to raw)
8183

8284
// mock tensor creation and closing; specify types so mockk can infer overload
8385
val mockTensor = mockk<OnnxTensor>(relaxed = true)
@@ -109,7 +111,7 @@ class ClipImageEmbedderInstrumentedTest {
109111
every { mockModel.getEnv() } returns mockk<OrtEnvironment>()
110112

111113
val raw = Array(1) { FloatArray(embedder.embeddingDim) { 1.0f } }
112-
every { mockModel.run(any<Map<String, ai.onnxruntime.OnnxTensorLike>>()) } returns mapOf("out" to raw)
114+
every { mockModel.run(any<Map<String, TensorData>>()) } returns mapOf("out" to raw)
113115

114116
val mockTensor = mockk<OnnxTensor>(relaxed = true)
115117
every { OnnxTensor.createTensor(any<OrtEnvironment>(), any<java.nio.FloatBuffer>(), any<LongArray>()) } returns mockTensor

core/src/androidTest/kotlin/com/fpf/smartscansdk/core/ml/embeddings/clip/ClipTextEmbedderTest.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import android.content.res.Resources
77
import androidx.test.core.app.ApplicationProvider
88
import com.fpf.smartscansdk.core.ml.models.OnnxModel
99
import com.fpf.smartscansdk.core.ml.models.ResourceId
10+
import com.fpf.smartscansdk.core.ml.models.TensorData
1011
import io.mockk.*
1112
import kotlinx.coroutines.runBlocking
1213
import org.junit.After
@@ -59,7 +60,7 @@ class ClipTextEmbedderInstrumentedTest {
5960
every { mockModel.getEnv() } returns mockk<OrtEnvironment>()
6061

6162
val raw = Array(1) { FloatArray(embedder.embeddingDim) { 1.0f } }
62-
every { mockModel.run(any<Map<String, ai.onnxruntime.OnnxTensorLike>>()) } returns mapOf("out" to raw)
63+
every { mockModel.run(any<Map<String, TensorData>>()) } returns mapOf("out" to raw)
6364

6465
val mockTensor = mockk<OnnxTensor>(relaxed = true)
6566
every { OnnxTensor.createTensor(any<OrtEnvironment>(), any<LongBuffer>(), any<LongArray>()) } returns mockTensor
@@ -85,7 +86,7 @@ class ClipTextEmbedderInstrumentedTest {
8586
every { mockModel.getEnv() } returns mockk<OrtEnvironment>()
8687

8788
val raw = Array(1) { FloatArray(embedder.embeddingDim) { 1.0f } }
88-
every { mockModel.run(any<Map<String, ai.onnxruntime.OnnxTensorLike>>()) } returns mapOf("out" to raw)
89+
every { mockModel.run(any<Map<String, TensorData>>()) } returns mapOf("out" to raw)
8990

9091
val mockTensor = mockk<OnnxTensor>(relaxed = true)
9192
every { OnnxTensor.createTensor(any<OrtEnvironment>(), any<LongBuffer>(), any<LongArray>()) } returns mockTensor

core/src/androidTest/kotlin/com/fpf/smartscansdk/core/ml/models/OnnxModelTest.kt

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import org.junit.Test
1919
import org.junit.Assert.assertTrue
2020
import org.junit.Assert.assertFalse
2121
import org.junit.Assert.assertEquals
22+
import java.nio.FloatBuffer
2223

2324

2425
class OnnxModelInstrumentedTest {
@@ -40,6 +41,7 @@ class OnnxModelInstrumentedTest {
4041

4142
// Mock session creation
4243
session = mockk(relaxed = true)
44+
4345
every { mockEnv.createSession(any<ByteArray>()) } returns session
4446

4547
// Construct the model
@@ -69,31 +71,42 @@ class OnnxModelInstrumentedTest {
6971
}
7072

7173
@Test
72-
fun `run returns mapped output`() {
73-
// Mock OnnxValue
74-
val value = mockk<OnnxValue>()
75-
every { value.value } returns floatArrayOf(1.0f)
74+
fun `run returns mapped output`() = runBlocking {
75+
// Prepare a fake FloatBuffer for input
76+
val fakeBuffer = FloatBuffer.allocate(1)
77+
fakeBuffer.put(1.0f)
78+
fakeBuffer.flip()
79+
80+
val tensorData = TensorData.FloatBufferTensor(fakeBuffer, longArrayOf(1))
7681

77-
// Mock entry
78-
val entry = mockk<MutableMap.MutableEntry<String, OnnxValue>>()
79-
every { entry.key } returns "out"
80-
every { entry.value } returns value
82+
// Mock OnnxTensor creation
83+
val onnxTensor = mockk<OnnxTensor>(relaxed = true)
84+
mockkStatic(OnnxTensor::class)
85+
every { OnnxTensor.createTensor(any<OrtEnvironment>(), any<FloatBuffer>(), any<LongArray>()) } returns onnxTensor
8186

82-
// Mock result iterator
87+
// Mock OnnxValue for result
88+
val onnxValue = mockk<OnnxValue>()
89+
every { onnxValue.value } returns floatArrayOf(1.0f)
90+
91+
// Prepare a mutable map entry and mutable iterator
92+
val resultMap = mutableMapOf("out" to onnxValue)
8393
val result = mockk<OrtSession.Result>()
84-
every { result.iterator() } returns mutableListOf(entry).iterator()
94+
every { result.iterator() } returns resultMap.entries.iterator()
8595
every { result.close() } just Runs
8696

87-
// Mock session.run()
88-
every { session.run(any<Map<String, OnnxTensorLike>>()) } returns result
97+
// Mock session.run() to return the mocked result
98+
every { session.run(any<Map<String, OnnxTensor>>()) } returns result
8999

90-
val inputs = mapOf("in" to mockk<OnnxTensorLike>())
100+
// Run the model
101+
val inputs = mapOf("in" to tensorData)
91102
val output = model.run(inputs)
92103

104+
// Assertions
93105
assertTrue(output.containsKey("out"))
94106
assertEquals(floatArrayOf(1.0f).toList(), (output["out"] as FloatArray).toList())
95107
}
96108

109+
97110
@Test
98111
fun `getInputNames returns session input names`() {
99112
// Use a LinkedHashSet to keep deterministic order

core/src/main/java/com/fpf/smartscansdk/core/ml/embeddings/EmbeddingTypes.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ data class PrototypeEmbedding(
1818

1919

2020
interface IEmbeddingStore {
21+
val isCached: Boolean
22+
val exists: Boolean
2123
suspend fun add(newEmbeddings: List<Embedding>)
2224
suspend fun remove(ids: List<Long>)
25+
suspend fun getAll(): List<Embedding> // getAll used instead of get to make clear that loading full index in memory is required
2326
fun clear()
2427
}
2528

extensions/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ val gitVersion: String by lazy {
100100
publishing {
101101
publications {
102102
register<MavenPublication>("release") {
103-
groupId = "com.github.dev-diaries41"
103+
groupId = "com.github.dev-diaries41.smartscan-sdk"
104104
artifactId = "smartscan-${project.name}"
105105
version = gitVersion
106106

extensions/src/androidTest/kotlin/com/fpf/smartscansdk/extensions/embeddings/EmbeddingLoadingBenchmarkTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class EmbeddingLoadingBenchmarkTest {
5555
val tempDir = File(context.cacheDir, "tempEmbeddings").apply { mkdirs() }
5656
val store = FileEmbeddingStore(tempDir, "embeddings.bin", embeddingLength)
5757

58-
store.save(embeddings.map { it.toEmbedding() })
58+
store.add(embeddings.map { it.toEmbedding() })
5959
store.clear() // force reload
6060

6161
val fileTime = measureNanoTime {

extensions/src/main/java/com/fpf/smartscansdk/extensions/embeddings/FileEmbeddingStore.kt

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ class FileEmbeddingStore(
1919
filename: String,
2020
private val embeddingLength: Int,
2121
val useCache: Boolean = true,
22-
2322
):
2423
IEmbeddingStore {
2524

@@ -28,18 +27,20 @@ class FileEmbeddingStore(
2827
}
2928

3029
private val file = File(dir, filename)
31-
private var cache: List<Embedding>? = null
30+
private var cache: LinkedHashMap<Long, Embedding>? = null
3231

33-
val exists: Boolean get() = file.exists()
32+
override val exists: Boolean get() = file.exists()
3433

35-
val isCached: Boolean
34+
override val isCached: Boolean
3635
get() = cache != null
3736

3837

3938
// prevent OOM in FileEmbeddingStore.save() by batching writes
40-
suspend fun save(embeddingsList: List<Embedding>): Unit = withContext(Dispatchers.IO) {
39+
private suspend fun save(embeddingsList: List<Embedding>): Unit = withContext(Dispatchers.IO) {
4140
if (embeddingsList.isEmpty()) return@withContext
42-
if(useCache){cache = embeddingsList}
41+
if(useCache){
42+
cache = LinkedHashMap(embeddingsList.associateBy { it.id })
43+
}
4344

4445
FileOutputStream(file).channel.use { channel ->
4546
val header = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN)
@@ -77,15 +78,15 @@ class FileEmbeddingStore(
7778
}
7879

7980
// This explicitly makes clear the design constraints that requires the full index to be loaded in memory
80-
suspend fun getAll(): List<Embedding> = withContext(Dispatchers.IO){
81-
cache?.let { return@withContext it };
81+
override suspend fun getAll(): List<Embedding> = withContext(Dispatchers.IO){
82+
cache?.let { return@withContext it.values.toList() };
8283

8384
FileInputStream(file).channel.use { ch ->
8485
val fileSize = ch.size()
8586
val buffer = ch.map(FileChannel.MapMode.READ_ONLY, 0, fileSize).order(ByteOrder.LITTLE_ENDIAN)
8687

8788
val count = buffer.int
88-
val list = ArrayList<Embedding>(count)
89+
val map = LinkedHashMap<Long, Embedding>(count)
8990

9091
repeat(count) {
9192
val id = buffer.long
@@ -94,10 +95,10 @@ class FileEmbeddingStore(
9495
val fb = buffer.asFloatBuffer()
9596
fb.get(floats)
9697
buffer.position(buffer.position() + embeddingLength * 4)
97-
list.add(Embedding(id, date, floats))
98+
map[id] = Embedding(id, date, floats)
9899
}
99-
if(useCache){cache = list}
100-
list
100+
if (useCache) cache = map
101+
map.values.toList()
101102
}
102103
}
103104

@@ -155,23 +156,39 @@ class FileEmbeddingStore(
155156
}
156157
}
157158
channel.force(false)
158-
if(useCache){cache = (cache ?: emptyList()) + newEmbeddings}
159+
if (useCache) {
160+
val map = cache ?: LinkedHashMap()
161+
for (e in newEmbeddings) map[e.id] = e
162+
cache = map
163+
}
159164
}
160165
}
161166

162167
override suspend fun remove(ids: List<Long>): Unit = withContext(Dispatchers.IO) {
163168
if (ids.isEmpty()) return@withContext
164169

165170
try {
166-
val embeddings = getAll()
167-
val remaining = embeddings.filter { it.id !in ids }
168-
save(remaining)
169-
Log.i(TAG, "Removed ${ids.size} stale embeddings")
171+
val map = cache ?: run {
172+
// Load all embeddings into the map if cache is empty
173+
val all = getAll()
174+
LinkedHashMap(all.associateBy { it.id })
175+
}
176+
177+
var removedCount = 0
178+
for (id in ids) {
179+
if (map.remove(id) != null) removedCount++
180+
}
181+
182+
if (removedCount > 0) {
183+
save(map.values.toList())
184+
Log.i(TAG, "Removed $removedCount stale embeddings")
185+
}
170186
} catch (e: Exception) {
171187
Log.e(TAG, "Error Removing embeddings", e)
172188
}
173189
}
174190

191+
175192
override fun clear(){
176193
cache = null
177194
}

extensions/src/main/java/com/fpf/smartscansdk/extensions/indexers/ImageIndexer.kt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import android.content.Context
66
import android.provider.MediaStore
77
import com.fpf.smartscansdk.core.utils.getBitmapFromUri
88
import com.fpf.smartscansdk.core.ml.embeddings.Embedding
9+
import com.fpf.smartscansdk.core.ml.embeddings.IEmbeddingStore
910
import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipConfig
1011
import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipImageEmbedder
1112
import com.fpf.smartscansdk.core.processors.BatchProcessor
1213
import com.fpf.smartscansdk.core.processors.IProcessorListener
1314
import com.fpf.smartscansdk.core.processors.ProcessOptions
14-
import com.fpf.smartscansdk.extensions.embeddings.FileEmbeddingStore
1515
import kotlinx.coroutines.NonCancellable
1616
import kotlinx.coroutines.withContext
1717

@@ -25,15 +25,13 @@ class ImageIndexer(
2525
application: Application,
2626
listener: IProcessorListener<Long, Embedding>? = null,
2727
options: ProcessOptions = ProcessOptions(),
28+
private val store: IEmbeddingStore,
2829
): BatchProcessor<Long, Embedding>(application, listener, options){
2930

3031
companion object {
3132
const val INDEX_FILENAME = "image_index.bin"
3233
}
3334

34-
// Cache not needed when indexing. This prevents unnecessary memory usage
35-
private val store = FileEmbeddingStore(application.filesDir, INDEX_FILENAME, ClipConfig.CLIP_EMBEDDING_LENGTH, useCache = false)
36-
3735
override suspend fun onBatchComplete(context: Context, batch: List<Embedding>) {
3836
store.add(batch)
3937
}

extensions/src/main/java/com/fpf/smartscansdk/extensions/indexers/VideoIndexer.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import android.content.ContentUris
55
import android.content.Context
66
import android.provider.MediaStore
77
import com.fpf.smartscansdk.core.ml.embeddings.Embedding
8+
import com.fpf.smartscansdk.core.ml.embeddings.IEmbeddingStore
89
import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipConfig.CLIP_EMBEDDING_LENGTH
910
import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipConfig.IMAGE_SIZE_X
1011
import com.fpf.smartscansdk.core.ml.embeddings.clip.ClipConfig.IMAGE_SIZE_Y
@@ -30,15 +31,13 @@ class VideoIndexer(
3031
application: Application,
3132
listener: IProcessorListener<Long, Embedding>? = null,
3233
options: ProcessOptions = ProcessOptions(),
33-
): BatchProcessor<Long, Embedding>(application, listener, options){
34+
private val store: IEmbeddingStore,
35+
): BatchProcessor<Long, Embedding>(application, listener, options){
3436

3537
companion object {
3638
const val INDEX_FILENAME = "video_index.bin"
3739
}
3840

39-
// Cache not needed when indexing. This prevents unnecessary memory usage
40-
private val store = FileEmbeddingStore(application.filesDir, INDEX_FILENAME, CLIP_EMBEDDING_LENGTH, useCache = false)
41-
4241
override suspend fun onBatchComplete(context: Context, batch: List<Embedding>) {
4342
store.add(batch)
4443
}

0 commit comments

Comments
 (0)