Commit 05f2efd

Expose zero-shot labeling to Python (#73)
* rm __pycache__
* gitignore __pycache__
* gitignore dist
* Upd usearch in image-search example
* Upd usearch in image-search example
* Implement ZSL in clip lib
* Use new ZSL API in examples
* Expose ZSL in Python
* Upd readme in Python bindings
* Bump version in Python bindings
1 parent c9c02cb commit 05f2efd

File tree

- clip.cpp
- clip.h
- examples/common-clip.h
- examples/image-search/CMakeLists.txt
- examples/python_bindings/README.md
- examples/python_bindings/clip_cpp/clip.py
- examples/python_bindings/example_main.py
- examples/python_bindings/pyproject.toml
- examples/zsl.cpp

9 files changed: +156 −62 lines

clip.cpp

Lines changed: 31 additions & 0 deletions
@@ -1518,6 +1518,37 @@ bool softmax_with_sorting(float * arr, const int length, float * sorted_scores,
     return true;
 }
 
+bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img,
+                                const char ** labels, const size_t n_labels, float * scores, int * indices) {
+    // load the image
+    clip_image_f32 img_res;
+
+    const int vec_dim = clip_get_vision_hparams(ctx)->projection_dim;
+
+    clip_image_preprocess(ctx, input_img, &img_res);
+
+    float img_vec[vec_dim];
+    if (!clip_image_encode(ctx, n_threads, &img_res, img_vec, false)) {
+        return false;
+    }
+
+    // encode texts and compute similarities
+    float txt_vec[vec_dim];
+    float similarities[n_labels];
+
+    for (int i = 0; i < n_labels; i++) {
+        const auto & text = labels[i];
+        auto tokens = clip_tokenize(ctx, text);
+        clip_text_encode(ctx, n_threads, &tokens, txt_vec, false);
+        similarities[i] = clip_similarity_score(img_vec, txt_vec, vec_dim);
+    }
+
+    // apply softmax and sort scores
+    softmax_with_sorting(similarities, n_labels, scores, indices);
+
+    return true;
+}
+
 bool image_normalize(const clip_image_u8 * img, clip_image_f32 * res) {
     if (img->nx != 224 || img->ny != 224) {
         printf("%s: long input shape: %d x %d\n", __func__, img->nx, img->ny);

clip.h

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,8 @@ bool clip_compare_text_and_image(const struct clip_ctx * ctx, const int n_thread
                                  const struct clip_image_u8 * image, float * score);
 float clip_similarity_score(const float * vec1, const float * vec2, const int vec_dim);
 bool softmax_with_sorting(float * arr, const int length, float * sorted_scores, int * indices);
+bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img,
+                                const char ** labels, const size_t n_labels, float * scores, int * indices);
 
 #ifdef __cplusplus
 }

examples/common-clip.h

Lines changed: 13 additions & 4 deletions
@@ -15,13 +15,22 @@ std::map<std::string, std::vector<std::string>> get_dir_keyed_files(const std::s
 
 bool is_image_file_extension(const std::string & path);
 
-struct app_params {
-    int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
+#include <algorithm>
+#include <string>
+#include <vector>
 
-    std::string model = "models/ggml-model-f16.bin";
+struct app_params {
+    int32_t n_threads;
+    std::string model;
     std::vector<std::string> image_paths;
     std::vector<std::string> texts;
-    int verbose = 1;
+    int verbose;
+
+    app_params()
+        : n_threads(std::min(4, static_cast<int32_t>(std::thread::hardware_concurrency()))), model("models/ggml-model-f16.bin"),
+          verbose(1) {
+        // Initialize other fields if needed
+    }
 };
 
 bool app_params_parse(int argc, char ** argv, app_params & params);

examples/image-search/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ set(CXX_STANDARD_REQUIRED ON)
 include(FetchContent)
 FetchContent_Declare(usearch
     GIT_REPOSITORY https://github.com/unum-cloud/usearch.git
-    GIT_TAG v0.20.0
+    GIT_TAG v2.5.0
 )
 FetchContent_MakeAvailable(usearch)
 

examples/python_bindings/README.md

Lines changed: 29 additions & 14 deletions
@@ -159,7 +159,20 @@ def compare_text_and_image(
 - `image_path` (str): The path to the image file for comparison.
 - `n_threads` (int, optional): The number of CPU threads to use for encoding (default is the number of CPU cores).
 
-#### 8. `__del__`
+#### 8. `zero_shot_label_image`
+
+```python
+def zero_shot_label_image(
+    self, image_path: str, labels: List[str], n_threads: int = os.cpu_count()
+) -> Tuple[List[float], List[int]]:
+```
+
+- **Description**: Zero-shot labels an image with given candidate labels, returning a tuple of sorted scores and indices.
+- `image_path` (str): The path to the image file to be labelled.
+- `labels` (List[str]): A list of candidate labels to be scored.
+- `n_threads` (int, optional): The number of CPU threads to use for encoding (default is the number of CPU cores).
+
+#### 9. `__del__`
 
 ```python
 def __del__(self):
@@ -175,17 +188,19 @@ A basic example can be found in the [clip.cpp examples](https://github.com/monat
 
 ```
 python example_main.py --help
-usage: clip [-h] -m MODEL [-v VERBOSITY] -t TEXT -i IMAGE
-
-optional arguments:
-  -h, --help            show this help message and exit
-  -m MODEL, --model MODEL
-                        path to GGML file
-  -v VERBOSITY, --verbosity VERBOSITY
-                        Level of verbosity. 0 = minimum, 2 = maximum
-  -t TEXT, --text TEXT  text to encode
-  -i IMAGE, --image IMAGE
-                        path to an image file
-```
+usage: clip [-h] -m MODEL [-fn FILENAME] [-v VERBOSITY] -t TEXT [TEXT ...] -i IMAGE
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -m MODEL, --model MODEL
+                        path to GGML file or repo_id
+  -fn FILENAME, --filename FILENAME
+                        path to GGML file in the Hugging face repo
+  -v VERBOSITY, --verbosity VERBOSITY
+                        Level of verbosity. 0 = minimum, 2 = maximum
+  -t TEXT [TEXT ...], --text TEXT [TEXT ...]
+                        text to encode. Multiple values allowed. In this case, apply zero-shot labeling
+  -i IMAGE, --image IMAGE
+                        path to an image file
+```
 
-Bindings to the DLL are implemented in `clip_cpp/clip.py` and
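
Putting the documented binding to work, here is a minimal end-to-end sketch; the model file, image path, and label set below are placeholders, not part of the commit:

```python
from clip_cpp import Clip

labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]

# Placeholder checkpoint path; any GGML CLIP model file should work here.
clip = Clip("models/ggml-model-f16.bin", 0)

# Scores arrive already softmaxed and sorted; each index points back
# into the original labels list.
sorted_scores, sorted_indices = clip.zero_shot_label_image("input.jpg", labels)

for ind, score in zip(sorted_indices, sorted_scores):
    print(f"{labels[ind]}: {score:.4f}")
```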

examples/python_bindings/clip_cpp/clip.py

Lines changed: 41 additions & 1 deletion
@@ -3,7 +3,7 @@
 import platform
 from glob import glob
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
 
 from .file_download import ModelInfo, model_download, model_info
 
@@ -167,6 +167,18 @@ class ClipContext(ctypes.Structure):
 ]
 clip_similarity_score.restype = ctypes.c_float
 
+clip_zero_shot_label_image = clip_lib.clip_zero_shot_label_image
+clip_zero_shot_label_image.argtypes = [
+    ctypes.POINTER(ClipContext),
+    ctypes.c_int,
+    ctypes.POINTER(ClipImageU8),
+    ctypes.POINTER(ctypes.c_char_p),
+    ctypes.c_ssize_t,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.POINTER(ctypes.c_int),
+]
+clip_zero_shot_label_image.restype = ctypes.c_bool
+
 softmax_with_sorting = clip_lib.softmax_with_sorting
 softmax_with_sorting.argtypes = [
     ctypes.POINTER(ctypes.c_float),
@@ -369,6 +381,34 @@ def compare_text_and_image(
 
         return score.value
 
+    def zero_shot_label_image(
+        self, image_path: str, labels: List[str], n_threads: int = os.cpu_count()
+    ) -> Tuple[List[float], List[int]]:
+        n_labels = len(labels)
+        if n_labels < 2:
+            raise ValueError(
+                "You must pass at least 2 labels for zero-shot image labeling"
+            )
+
+        labels = (ctypes.c_char_p * n_labels)(
+            *[ctypes.c_char_p(label.encode("utf8")) for label in labels]
+        )
+        image_ptr = make_clip_image_u8()
+        if not clip_image_load_from_file(image_path.encode("utf8"), image_ptr):
+            raise RuntimeError(f"Could not load image {image_path}")
+
+        scores = (ctypes.c_float * n_labels)()
+        indices = (ctypes.c_int * n_labels)()
+        if not clip_zero_shot_label_image(
+            self.ctx, n_threads, image_ptr, labels, n_labels, scores, indices
+        ):
+            print("function called")
+            raise RuntimeError("Could not zero-shot label image")
+
+        return [scores[i] for i in range(n_labels)], [
+            indices[i] for i in range(n_labels)
+        ]
+
     def __del__(self):
         if hasattr(self, "ctx"):
             clip_free(self.ctx)

examples/python_bindings/example_main.py

Lines changed: 25 additions & 14 deletions
@@ -5,28 +5,39 @@
 if __name__ == "__main__":
     ap = argparse.ArgumentParser(prog="clip")
     ap.add_argument("-m", "--model", help="path to GGML file or repo_id", required=True)
-    ap.add_argument("-fn", "--filename", help="path to GGML file in the Hugging face repo", required=False)
+    ap.add_argument(
+        "-fn",
+        "--filename",
+        help="path to GGML file in the Hugging face repo",
+        required=False,
+    )
     ap.add_argument(
         "-v",
         "--verbosity",
         type=int,
         help="Level of verbosity. 0 = minimum, 2 = maximum",
         default=0,
     )
-    ap.add_argument("-t", "--text", help="text to encode", required=True)
+    ap.add_argument(
+        "-t",
+        "--text",
+        help="text to encode. Multiple values allowed. In this case, apply zero-shot labeling",
+        nargs="+",
+        type=str,
+        required=True,
+    )
    ap.add_argument("-i", "--image", help="path to an image file", required=True)
     args = ap.parse_args()
 
     clip = Clip(args.model, args.verbosity)
-
-    tokens = clip.tokenize(args.text)
-    text_embed = clip.encode_text(tokens)
-
-    image_embed = clip.load_preprocess_encode_image(args.image)
-
-    score = clip.calculate_similarity(text_embed, image_embed)
-
-    # Alternatively, you can just do:
-    # score = clip.compare_text_and_image(text, image_path)
-
-    print(f"Similarity score: {score}")
+    if len(args.text) == 1:
+        score = clip.compare_text_and_image(args.text[0], args.image)
+
+        print(f"Similarity score: {score}")
+    else:
+        sorted_scores, sorted_indices = clip.zero_shot_label_image(
+            args.image, args.text
+        )
+        for ind, score in zip(sorted_indices, sorted_scores):
+            label = args.text[ind]
+            print(f"{label}: {score:.4f}")

examples/python_bindings/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "clip_cpp"
-version = "0.4.1"
+version = "0.4.2"
 description = "CLIP inference with no big dependencies as PyTorch, TensorFlow, Numpy"
 authors = ["Yusuf Sarıgöz <[email protected]>"]
 packages = [{ include = "clip_cpp" }]

examples/zsl.cpp

Lines changed: 13 additions & 27 deletions
@@ -10,11 +10,16 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    int n_labels = params.texts.size();
+    const size_t n_labels = params.texts.size();
     if (n_labels < 2) {
         printf("%s: You must specify at least 2 texts for zero-shot labeling\n", __func__);
     }
 
+    const char * labels[n_labels];
+    for (size_t i = 0; i < n_labels; ++i) {
+        labels[i] = params.texts[i].c_str();
+    }
+
     auto ctx = clip_model_load(params.model.c_str(), params.verbose);
     if (!ctx) {
         printf("%s: Unable to load model from %s", __func__, params.model.c_str());
@@ -23,40 +28,21 @@ int main(int argc, char ** argv) {
 
     // load the image
     const auto & img_path = params.image_paths[0].c_str();
-    clip_image_u8 img0;
-    clip_image_f32 img_res;
-    if (!clip_image_load_from_file(img_path, &img0)) {
+    clip_image_u8 input_img;
+    if (!clip_image_load_from_file(img_path, &input_img)) {
         fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, img_path);
         return 1;
     }
 
-    const int vec_dim = clip_get_vision_hparams(ctx)->projection_dim;
-
-    clip_image_preprocess(ctx, &img0, &img_res);
-
-    float img_vec[vec_dim];
-    if (!clip_image_encode(ctx, params.n_threads, &img_res, img_vec, false)) {
+    float sorted_scores[n_labels];
+    int sorted_indices[n_labels];
+    if (!clip_zero_shot_label_image(ctx, params.n_threads, &input_img, labels, n_labels, sorted_scores, sorted_indices)) {
+        fprintf(stderr, "Unable to apply ZSL\n");
         return 1;
     }
 
-    // encode texts and compute similarities
-    float txt_vec[vec_dim];
-    float similarities[n_labels];
-
-    for (int i = 0; i < n_labels; i++) {
-        const auto & text = params.texts[i].c_str();
-        auto tokens = clip_tokenize(ctx, text);
-        clip_text_encode(ctx, params.n_threads, &tokens, txt_vec, false);
-        similarities[i] = clip_similarity_score(img_vec, txt_vec, vec_dim);
-    }
-
-    // apply softmax and sort scores
-    float sorted_scores[n_labels];
-    int indices[n_labels];
-    softmax_with_sorting(similarities, n_labels, sorted_scores, indices);
-
     for (int i = 0; i < n_labels; i++) {
-        auto label = params.texts[indices[i]].c_str();
+        auto label = labels[sorted_indices[i]];
         float score = sorted_scores[i];
         printf("%s = %1.4f\n", label, score);
     }
