Skip to content

Commit dbeedb2

Browse files
Merge pull request #24 from smartscanapp/staging
SmartScan SDK V1.2.0
2 parents 7d23c73 + b54fc77 commit dbeedb2

File tree

43 files changed

+646
-345
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+646
-345
lines changed

core/src/androidTest/kotlin/com/fpf/smartscansdk/core/data/images/Entity.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package com.fpf.smartscansdk.core.data.images
22

33
import androidx.room.*
4-
import com.fpf.smartscansdk.core.data.Embedding
4+
import com.fpf.smartscansdk.core.embeddings.Embedding
55

66
@Entity(tableName = "image_embeddings")
77
data class ImageEmbeddingEntity(

core/src/main/java/com/fpf/smartscansdk/core/data/Classification.kt

Lines changed: 0 additions & 9 deletions
This file was deleted.

core/src/main/java/com/fpf/smartscansdk/core/data/Processors.kt

Lines changed: 0 additions & 20 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
package com.fpf.smartscansdk.ml.models.providers.embeddings
1+
package com.fpf.smartscansdk.core.embeddings
22

3-
import com.fpf.smartscansdk.core.data.ClassificationError
4-
import com.fpf.smartscansdk.core.data.ClassificationResult
5-
import com.fpf.smartscansdk.core.data.PrototypeEmbedding
6-
import com.fpf.smartscansdk.core.embeddings.getSimilarities
7-
import com.fpf.smartscansdk.core.embeddings.getTopN
3+
sealed class ClassificationResult {
4+
data class Success(val classId: String, val similarity: Float ): ClassificationResult()
5+
data class Failure(val error: ClassificationError ): ClassificationResult()
6+
}
87

8+
enum class ClassificationError{MINIMUM_CLASS_SIZE, THRESHOLD, CONFIDENCE_MARGIN}
99

10-
fun classify(embedding: FloatArray, classPrototypes: List<PrototypeEmbedding>, threshold: Float = 0.4f, confidenceMargin: Float = 0.05f ): ClassificationResult{
10+
fun fewShotClassify(embedding: FloatArray, classPrototypes: List<PrototypeEmbedding>, threshold: Float = 0.4f, confidenceMargin: Float = 0.05f ): ClassificationResult{
1111
if(classPrototypes.size < 2) return ClassificationResult.Failure(error= ClassificationError.MINIMUM_CLASS_SIZE) // Using a single class prototype leads to many false positives
1212

1313
// No threshold filter applied here to allow confidence check by comparing top 2 matches
@@ -24,6 +24,4 @@ fun classify(embedding: FloatArray, classPrototypes: List<PrototypeEmbedding>, t
2424

2525
val classId = classPrototypes[bestIndex].id
2626
return ClassificationResult.Success(classId=classId, similarity = bestSim)
27-
}
28-
29-
27+
}

core/src/main/java/com/fpf/smartscansdk/core/embeddings/EmbeddingUtils.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,7 @@ fun unflattenEmbeddings(flattened: FloatArray, embeddingDim: Int): List<FloatArr
5656
}
5757
return embeddings
5858
}
59+
60+
61+
62+

core/src/main/java/com/fpf/smartscansdk/core/data/Embedings.kt renamed to core/src/main/java/com/fpf/smartscansdk/core/embeddings/Embeddings.kt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
package com.fpf.smartscansdk.core.data
2-
1+
package com.fpf.smartscansdk.core.embeddings
32

43
// `Embedding` represents a raw vector for a single media item, with `id` corresponding to its `MediaStoreId`.
54
data class Embedding(
@@ -14,7 +13,3 @@ data class PrototypeEmbedding(
1413
val date: Long,
1514
val embeddings: FloatArray
1615
)
17-
18-
19-
20-

core/src/main/java/com/fpf/smartscansdk/core/embeddings/FileEmbeddingRetriever.kt

Lines changed: 0 additions & 43 deletions
This file was deleted.

core/src/main/java/com/fpf/smartscansdk/core/embeddings/FileEmbeddingStore.kt

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package com.fpf.smartscansdk.core.embeddings
22

33
import android.util.Log
4-
import com.fpf.smartscansdk.core.data.Embedding
54
import kotlinx.coroutines.Dispatchers
65
import kotlinx.coroutines.withContext
76
import java.io.File
@@ -12,6 +11,7 @@ import java.io.RandomAccessFile
1211
import java.nio.ByteBuffer
1312
import java.nio.ByteOrder
1413
import java.nio.channels.FileChannel
14+
import kotlin.collections.map
1515

1616
class FileEmbeddingStore(
1717
private val file: File,
@@ -25,13 +25,10 @@ class FileEmbeddingStore(
2525
}
2626

2727
private var cache: LinkedHashMap<Long, Embedding>? = null
28+
private var cachedIds: List<Long>? = null
2829

2930
override val exists: Boolean get() = file.exists()
3031

31-
override val isCached: Boolean
32-
get() = cache != null
33-
34-
3532
// prevent OOM in FileEmbeddingStore.save() by batching writes
3633
private suspend fun save(embeddingsList: List<Embedding>): Unit = withContext(Dispatchers.IO) {
3734
if (embeddingsList.isEmpty()) return@withContext
@@ -99,6 +96,19 @@ class FileEmbeddingStore(
9996
}
10097
}
10198

99+
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
100+
val map = cache ?: run {
101+
val all = get()
102+
LinkedHashMap(all.associateBy { it.id })
103+
}
104+
val embeddings = mutableListOf<Embedding>()
105+
106+
for (id in ids) {
107+
map.get(id)?.let { embeddings.add(it) }
108+
}
109+
embeddings
110+
}
111+
102112
override suspend fun add(newEmbeddings: List<Embedding>): Unit = withContext(Dispatchers.IO) {
103113
if (newEmbeddings.isEmpty()) return@withContext
104114

@@ -185,29 +195,42 @@ class FileEmbeddingStore(
185195
}
186196
}
187197

188-
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
189-
val map = cache ?: run {
190-
val all = get()
191-
LinkedHashMap(all.associateBy { it.id })
192-
}
193-
val embeddings = mutableListOf<Embedding>()
194198

195-
for (id in ids) {
196-
map.get(id)?.let { embeddings.add(it) }
197-
}
198-
embeddings
199+
override fun clear(){
200+
cache = null
199201
}
200202

201-
suspend fun get(id: Long): Embedding? = withContext(Dispatchers.IO) {
202-
val map = cache ?: run {
203-
val all = get()
204-
LinkedHashMap(all.associateBy { it.id })
203+
204+
override suspend fun query(embedding: FloatArray, topK: Int, threshold: Float): List<Embedding> {
205+
206+
cachedIds = null // clear on new search
207+
208+
val storedEmbeddings = get()
209+
210+
if (storedEmbeddings.isEmpty()) return emptyList()
211+
212+
val similarities = getSimilarities(embedding, storedEmbeddings.map { it.embeddings })
213+
val resultIndices = getTopN(similarities, topK, threshold)
214+
215+
if (resultIndices.isEmpty()) return emptyList()
216+
217+
val idsToCache = mutableListOf<Long>()
218+
val results = resultIndices.map{idx ->
219+
idsToCache.add( storedEmbeddings[idx].id)
220+
storedEmbeddings[idx]
205221
}
206-
map.get(id)
222+
cachedIds = idsToCache
223+
return results
207224
}
208225

209-
override fun clear(){
210-
cache = null
226+
suspend fun query(start: Int, end: Int): List<Embedding> {
227+
val ids = cachedIds ?: return emptyList()
228+
val s = start.coerceAtLeast(0)
229+
val e = end.coerceAtMost(ids.size)
230+
if (s >= e) return emptyList()
231+
232+
val batch = get(ids.subList(s, e))
233+
return batch
211234
}
212235

213236
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.fpf.smartscansdk.core.embeddings
2+
3+
interface IDetectorProvider<T> {
4+
fun closeSession() = Unit
5+
suspend fun initialize()
6+
fun isInitialized(): Boolean
7+
suspend fun detect(data: T): Pair<List<Float>, List<FloatArray>>
8+
}
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package com.fpf.smartscansdk.core.embeddings
22

3-
import com.fpf.smartscansdk.core.data.Embedding
4-
53
interface IEmbeddingStore {
6-
val isCached: Boolean
74
val exists: Boolean
85
suspend fun add(newEmbeddings: List<Embedding>)
96
suspend fun remove(ids: List<Long>)
107
suspend fun get(): List<Embedding>
118
fun clear()
9+
10+
suspend fun query(
11+
embedding: FloatArray,
12+
topK: Int,
13+
threshold: Float
14+
): List<Embedding>
1215
}

0 commit comments

Comments
 (0)