Skip to content

Commit d694dac

Browse files
committed
cleanup and refactor *again*
1 parent e3261ff commit d694dac

File tree

7 files changed

+151
-65
lines changed

7 files changed

+151
-65
lines changed

examples/llava/llava-cli.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,47 @@ static void show_additional_info(int /*argc*/, char ** argv) {
2020
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
2121
}
2222

23-
static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) {
23+
static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
24+
2425
// load and preprocess the image
25-
clip_image_u8 * img = make_clip_image_u8();
26+
llava_image_embed * embed = NULL;
2627
auto prompt = params->prompt;
2728
if (prompt_contains_image(prompt)) {
2829
if (!params->image.empty()) {
2930
printf("using base64 encoded image instead of command line image path\n");
3031
}
31-
if (!clip_image_load_from_prompt(prompt, img)) {
32+
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
33+
if (!embed) {
3234
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
33-
return false;
35+
return NULL;
3436
}
3537
params->prompt = remove_image_from_prompt(prompt);
3638
} else {
37-
if (!clip_image_load_from_file(params->image.c_str(), img)) {
39+
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
40+
if (!embed) {
3841
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
39-
return false;
42+
return NULL;
4043
}
4144
}
42-
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, img, image_embd, n_img_pos);
43-
if (!image_embed_result) {
44-
clip_image_u8_free(img);
45-
fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
46-
return false;
47-
}
4845

49-
return true;
46+
return embed;
5047
}
5148

52-
static void process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
49+
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const char * prompt) {
5350
int n_past = 0;
5451

5552
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
5653

5754
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
5855
// GG: are we sure that there should be a trailing whitespace at the end of this string?
56+
printf("evaluating system prompt\n");
5957
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
60-
llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
58+
printf("evaluating image embed\n");
59+
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
60+
printf("evaluating prompt\n");
6161
eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
6262
eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
63+
printf("awaiting response\n");
6364

6465
// generate the response
6566

@@ -153,16 +154,14 @@ int main(int argc, char ** argv) {
153154
return 1;
154155
}
155156

156-
float * image_embd;
157-
int n_image_pos;
158-
load_image(ctx_llava, &params, &image_embd, &n_image_pos);
157+
auto image_embed = load_image(ctx_llava, &params);
159158

160159
// process the prompt
161-
process_prompt(ctx_llava, image_embd, n_image_pos, &params, params.prompt.c_str());
160+
process_prompt(ctx_llava, image_embed, &params, params.prompt.c_str());
162161

163162
llama_print_timings(ctx_llava->ctx_llama);
164163

165-
free(image_embd);
164+
llava_image_embed_free(image_embed);
166165
llava_free(ctx_llava);
167166
return 0;
168167
}

ggml-metal.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
189189
{
190190
NSBundle * bundle = nil;
191191
#ifdef SWIFT_PACKAGE
192+
print("would use SWIFTPM_MODULE_BUNDLE");
192193
bundle = SWIFTPM_MODULE_BUNDLE;
193194
#else
194195
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];

llava/clip.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
678678
return new_clip;
679679
}
680680

681-
clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); }
681+
// Allocates a new, value-initialized clip_image_u8 with `new` and hands it to the caller.
// clip_image_u8_free() in this file is the matching release function.
clip_image_u8 * make_clip_image_u8() {
    return new clip_image_u8();
}
682685
// Allocates a new, value-initialized clip_image_f32 with `new` and hands it to the caller.
clip_image_f32 * make_clip_image_f32() {
    auto * img = new clip_image_f32();
    return img;
}
683686

684687
// Releases an image created by make_clip_image_u8(): first its pixel buffer, then the struct itself.
// Passing a null pointer is a no-op, so callers can free unconditionally on error paths.
void clip_image_u8_free(clip_image_u8 * img) {
    if (!img) {
        return;
    }
    // delete[] on a null pointer is well-defined, so no separate check of img->data is needed
    delete[] img->data;
    delete img;
}
@@ -692,31 +695,30 @@ static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_
692695
memcpy(img->data, data, img->size);
693696
}
694697

695-
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img) {
698+
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
696699
int nx, ny, nc;
697-
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
700+
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
698701
if (!data) {
699-
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
702+
fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname);
700703
return false;
701704
}
702705
build_clip_img_from_data(data, nx, ny, img);
703706
stbi_image_free(data);
704707
return true;
705708
}
706709

707-
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
710+
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) {
708711
int nx, ny, nc;
709-
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
712+
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
710713
if (!data) {
711-
fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname);
714+
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
712715
return false;
713716
}
714717
build_clip_img_from_data(data, nx, ny, img);
715718
stbi_image_free(data);
716719
return true;
717720
}
718721

719-
720722
// normalize: x = (x - mean) / std
721723
// TODO: implement bicubic interpolation instead of linear.
722724
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
@@ -1065,16 +1067,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
10651067
return true;
10661068
}
10671069

1068-
int clip_n_mmproj_embd(struct clip_ctx * ctx) {
1070+
int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
10691071
return ctx->vision_model.mm_2_b->ne[0];
10701072
}
10711073

1072-
int clip_n_patches(struct clip_ctx * ctx) {
1074+
int clip_n_patches(const struct clip_ctx * ctx) {
10731075
auto & params = ctx->vision_model.hparams;
10741076

10751077
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
10761078
}
10771079

1078-
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
1080+
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
10791081
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
10801082
}

llava/clip.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
2525

2626
void clip_free(struct clip_ctx * ctx);
2727

28-
size_t clip_embd_nbytes(struct clip_ctx * ctx);
29-
int clip_n_patches(struct clip_ctx * ctx);
30-
int clip_n_mmproj_embd(struct clip_ctx * ctx);
28+
size_t clip_embd_nbytes(const struct clip_ctx * ctx);
29+
int clip_n_patches(const struct clip_ctx * ctx);
30+
int clip_n_mmproj_embd(const struct clip_ctx * ctx);
3131

3232
// RGB uint8 image
3333
struct clip_image_u8 {
@@ -62,7 +62,7 @@ LLAMA_API void clip_image_u8_free(clip_image_u8 * img);
6262
LLAMA_API void clip_image_f32_free(clip_image_f32 * img);
6363
LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
6464
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
65-
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img);
65+
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
6666

6767
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
6868
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);

llava/llava-utils.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "common.h"
66
#include "llama.h"
7+
#include "llava.h"
78

89
#include "base64.hpp"
910

@@ -143,12 +144,12 @@ inline bool prompt_contains_image(const std::string& prompt) {
143144
}
144145

145146
// builds an image embed from the base64 image tag found in `prompt`; returns NULL if the tag is invalid or the image cannot be loaded
146-
inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8 * img) {
147+
inline llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
147148
size_t img_base64_str_start, img_base64_str_end;
148149
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
149150
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
150151
fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
151-
return false;
152+
return NULL;
152153
}
153154

154155
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
@@ -157,16 +158,15 @@ inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8
157158

158159
auto required_bytes = base64::required_encode_size(base64_str.size());
159160
auto img_bytes = std::vector<unsigned char>(required_bytes);
160-
auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
161-
size_t img_bytes_len = img_bytes_end - img_bytes.begin();
161+
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
162162

163-
auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
164-
if (!img_loaded_ok) {
163+
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
164+
if (!embed) {
165165
fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
166-
return false;
166+
return NULL;
167167
}
168168

169-
return true;
169+
return embed;
170170
}
171171

172172
inline std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {

llava/llava.cpp

Lines changed: 94 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include "base64.hpp"
1212

13-
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) {
13+
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
1414
clip_image_f32 * img_res = make_clip_image_f32();
1515
if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
1616
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
@@ -19,7 +19,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
1919
}
2020

2121
*n_img_pos = clip_n_patches(ctx_clip);
22-
*n_image_embd = clip_n_mmproj_embd(ctx_clip);
2322

2423
const int64_t t_img_enc_start_us = ggml_time_us();
2524
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
@@ -39,7 +38,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
3938
return true;
4039
}
4140

42-
bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
41+
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
42+
// make sure that the correct mmproj was used, i.e., compare apples to apples
43+
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
44+
auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
45+
if (n_image_embd != n_llama_embd) {
46+
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
47+
return false;
48+
}
49+
return true;
50+
}
51+
52+
static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
4353

4454
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
4555
if (!image_embd) {
@@ -49,20 +59,11 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
4959
}
5060

5161
int n_img_pos;
52-
int n_image_embd;
53-
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_image_embd, &n_img_pos)) {
62+
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
5463
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
5564
free(image_embd);
5665
return false;
5766
}
58-
// make sure that the correct mmproj was used, i.e., compare apples to apples
59-
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
60-
if (n_image_embd != n_llama_embd) {
61-
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
62-
free(image_embd);
63-
return false;
64-
}
65-
6667
*image_embd_out = image_embd;
6768
*n_img_pos_out = n_img_pos;
6869

@@ -71,15 +72,15 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
7172

7273

7374

74-
bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) {
75+
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
7576
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
7677

77-
for (int i = 0; i < n_image_pos; i += n_batch) {
78-
int n_eval = n_image_pos - i;
78+
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
79+
int n_eval = image_embed->n_image_pos - i;
7980
if (n_eval > n_batch) {
8081
n_eval = n_batch;
8182
}
82-
llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
83+
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
8384
if (llama_decode(ctx_llama, batch)) {
8485
fprintf(stderr, "%s : failed to eval\n", __func__);
8586
return false;
@@ -88,3 +89,79 @@ bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_
8889
}
8990
return true;
9091
}
92+
93+
94+
// Decodes an in-memory image file and builds its CLIP embedding.
//
//   ctx_clip           - clip context used to preprocess and encode the image
//   n_threads          - thread count forwarded to the encoder
//   image_bytes        - raw bytes of an encoded image file
//   image_bytes_length - number of bytes in `image_bytes`
//
// Returns a malloc'ed llava_image_embed that the caller releases with llava_image_embed_free(),
// or NULL if the input is empty, the bytes cannot be decoded, the image cannot be embedded,
// or memory cannot be allocated.
LLAMA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length)
{
    // reject empty input up front: a negative length would otherwise be widened to a huge size_t
    if (image_bytes == NULL || image_bytes_length <= 0) {
        fprintf(stderr, "%s: no image bytes provided\n", __func__);
        return NULL;
    }

    clip_image_u8 * img = make_clip_image_u8();
    if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
        clip_image_u8_free(img);
        fprintf(stderr, "%s: can't load image from bytes, is it a valid image?\n", __func__);
        return NULL;
    }

    float * image_embed = NULL;
    int n_image_pos = 0;
    const bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);

    // the decoded pixels are no longer needed whether or not the embedding succeeded
    clip_image_u8_free(img);

    if (!image_embed_result) {
        fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
        return NULL;
    }

    auto result = (llava_image_embed *) malloc(sizeof(llava_image_embed));
    if (!result) {
        fprintf(stderr, "%s: failed to allocate the image embed struct\n", __func__);
        // the embedding buffer was malloc'ed by llava_image_embed_make_with_clip_img; don't leak it
        free(image_embed);
        return NULL;
    }
    result->embed       = image_embed;
    result->n_image_pos = n_image_pos;
    return result;
}
118+
119+
// Reads the whole file at `path` into a freshly malloc'ed buffer.
//
//   path     - file to read (opened in binary mode)
//   bytesOut - receives the buffer on success; the caller releases it with free()
//   sizeOut  - receives the number of bytes read on success
//
// Returns true on success. On any failure (open, seek, allocation, short read) a message is
// written to stderr, nothing is leaked, the outputs are left untouched and false is returned.
static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut)
{
    auto file = fopen(path, "rb");
    if (file == NULL) {
        fprintf(stderr, "%s: can't read file %s\n", __func__, path);
        return false;
    }

    // determine the file size; ftell returns -1 on error, which must never reach malloc
    long fileSize = -1;
    if (fseek(file, 0, SEEK_END) == 0) {
        fileSize = ftell(file);
    }
    if (fileSize < 0 || fseek(file, 0, SEEK_SET) != 0) {
        fprintf(stderr, "%s: can't determine size of file %s\n", __func__, path);
        fclose(file);
        return false;
    }

    // malloc(0) may legitimately return NULL, so always request at least one byte
    auto buffer = (unsigned char *)malloc(fileSize > 0 ? (size_t) fileSize : 1); // Allocate memory to hold the file data
    if (buffer == NULL) {
        fprintf(stderr, "%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
        perror("Memory allocation error");
        fclose(file);
        return false;
    }

    const size_t bytesRead = fread(buffer, 1, (size_t) fileSize, file); // Read the file into the buffer
    fclose(file); // Close the file

    // a short read would hand uninitialised bytes to the image decoder
    if (bytesRead != (size_t) fileSize) {
        fprintf(stderr, "%s: read only %zu of %ld bytes from file %s\n", __func__, bytesRead, fileSize, path);
        free(buffer);
        return false;
    }

    *bytesOut = buffer;
    *sizeOut = fileSize;
    return true;
}
146+
147+
// Loads the image file at `image_path` into memory and builds its CLIP embedding via
// llava_image_embed_make_with_bytes().
//
// Returns a llava_image_embed that the caller releases with llava_image_embed_free(),
// or NULL if the file cannot be read, is too large for the int-based byte API,
// or cannot be turned into an embedding.
LLAMA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path)
{
    unsigned char* image_bytes = NULL;
    long image_bytes_length = 0;
    auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
    if (!loaded) {
        fprintf(stderr, "%s: failed to load %s\n", __func__, image_path);
        return NULL;
    }

    // llava_image_embed_make_with_bytes takes an int length: refuse sizes that do not
    // survive the long -> int conversion instead of silently truncating them
    const int n_bytes = (int) image_bytes_length;
    if (n_bytes < 0 || (long) n_bytes != image_bytes_length) {
        fprintf(stderr, "%s: file %s is too large (%ld bytes)\n", __func__, image_path, image_bytes_length);
        free(image_bytes);
        return NULL;
    }

    auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, n_bytes);
    free(image_bytes);

    return embed;
}
162+
163+
164+
// Releases an embed returned by one of the llava_image_embed_make_with_* functions:
// first the embedding buffer, then the struct itself.
// Passing NULL is a no-op, so the result of a failed make call can be freed unconditionally.
LLAMA_API void llava_image_embed_free(struct llava_image_embed * embed) {
    if (!embed) {
        return;
    }
    free(embed->embed);
    free(embed);
}

0 commit comments

Comments
 (0)