Skip to content

Commit e3261ff

Browse files
committed
cleanup memory usage around clip_image_*
1 parent 2847ecf commit e3261ff

File tree

4 files changed

+43
-32
lines changed

4 files changed

+43
-32
lines changed

examples/llava/llava-cli.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,26 @@ static void show_additional_info(int /*argc*/, char ** argv) {
2222

2323
static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) {
2424
// load and preprocess the image
25-
clip_image_u8 img;
25+
clip_image_u8 * img = make_clip_image_u8();
2626
auto prompt = params->prompt;
2727
if (prompt_contains_image(prompt)) {
2828
if (!params->image.empty()) {
2929
printf("using base64 encoded image instead of command line image path\n");
3030
}
31-
if (!clip_image_load_from_prompt(prompt, &img)) {
31+
if (!clip_image_load_from_prompt(prompt, img)) {
3232
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
3333
return false;
3434
}
3535
params->prompt = remove_image_from_prompt(prompt);
3636
} else {
37-
if (!clip_image_load_from_file(params->image.c_str(), &img)) {
37+
if (!clip_image_load_from_file(params->image.c_str(), img)) {
3838
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
3939
return false;
4040
}
4141
}
42-
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, &img, image_embd, n_img_pos);
42+
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, img, image_embd, n_img_pos);
4343
if (!image_embed_result) {
44+
clip_image_u8_free(img);
4445
fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
4546
return false;
4647
}

llava/clip.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -679,9 +679,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
679679
}
680680

681681
clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); }
682-
683682
clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); }
684683

684+
void clip_image_u8_free(clip_image_u8 * img) { if (img->data) { delete[] img->data; } delete img; }
685+
void clip_image_f32_free(clip_image_f32 * img) { if (img->data) { delete[] img->data; } delete img; }
686+
685687
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
686688
img->nx = nx;
687689
img->ny = ny;
@@ -726,39 +728,40 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
726728
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
727729
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
728730

729-
clip_image_u8 temp; // we will keep the input image data here temporarily
731+
clip_image_u8 * temp = make_clip_image_u8(); // we will keep the input image data here temporarily
730732
if (pad2square && img->nx != img->ny) {
731733
int longer_side = std::max(img->nx, img->ny);
732-
temp.nx = longer_side;
733-
temp.ny = longer_side;
734-
temp.size = 3 * longer_side * longer_side;
735-
temp.data = new uint8_t[temp.size]();
734+
temp->nx = longer_side;
735+
temp->ny = longer_side;
736+
temp->size = 3 * longer_side * longer_side;
737+
temp->data = new uint8_t[temp->size]();
736738
uint8_t bc[3] = {122, 116, 104}; // bakground color in RGB from LLaVA
737739

738740
// fill with background color
739-
for (size_t i = 0; i < temp.size; i++) {
740-
temp.data[i] = bc[i % 3];
741+
for (size_t i = 0; i < temp->size; i++) {
742+
temp->data[i] = bc[i % 3];
741743
}
742744

743745
// copy from the input image
744746
for (int y = 0; y < img->ny; y++) {
745747
for (int x = 0; x < img->nx; x++) {
746748
const int i = 3 * (y * img->nx + x);
747-
const int j = 3 * (y * temp.nx + x);
748-
temp.data[j] = img->data[i];
749-
temp.data[j+1] = img->data[i+1];
750-
temp.data[j+2] = img->data[i+2];
749+
const int j = 3 * (y * temp->nx + x);
750+
temp->data[j] = img->data[i];
751+
temp->data[j+1] = img->data[i+1];
752+
temp->data[j+2] = img->data[i+2];
751753
}
752754
}
753755
} else {
754-
temp.nx = img->nx;
755-
temp.ny = img->ny;
756-
temp.size = img->size;
757-
temp.data = img->data;
756+
temp->nx = img->nx;
757+
temp->ny = img->ny;
758+
temp->size = img->size;
759+
temp->data = new uint8_t[temp->size]();
760+
*temp->data = *img->data; // copy
758761
}
759762

760-
const int nx = temp.nx;
761-
const int ny = temp.ny;
763+
const int nx = temp->nx;
764+
const int ny = temp->ny;
762765

763766
const int nx2 = ctx->vision_model.hparams.image_size;
764767
const int ny2 = ctx->vision_model.hparams.image_size;
@@ -797,10 +800,10 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
797800
const int j10 = 3 * (y1 * nx + x0) + c;
798801
const int j11 = 3 * (y1 * nx + x1) + c;
799802

800-
const float v00 = temp.data[j00];
801-
const float v01 = temp.data[j01];
802-
const float v10 = temp.data[j10];
803-
const float v11 = temp.data[j11];
803+
const float v00 = temp->data[j00];
804+
const float v01 = temp->data[j01];
805+
const float v10 = temp->data[j10];
806+
const float v11 = temp->data[j11];
804807

805808
const float v0 = v00 * (1.0f - dx) + v01 * dx;
806809
const float v1 = v10 * (1.0f - dx) + v11 * dx;
@@ -815,6 +818,7 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
815818
}
816819
}
817820
}
821+
clip_image_u8_free(temp);
818822

819823
return true;
820824
}

llava/clip.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ int clip_n_mmproj_embd(struct clip_ctx * ctx);
3333
struct clip_image_u8 {
3434
int nx;
3535
int ny;
36-
uint8_t * data;
36+
uint8_t * data = NULL;
3737
size_t size;
3838
};
3939

@@ -42,7 +42,7 @@ struct clip_image_u8 {
4242
struct clip_image_f32 {
4343
int nx;
4444
int ny;
45-
float * data;
45+
float * data = NULL;
4646
size_t size;
4747
};
4848

@@ -58,8 +58,12 @@ struct clip_image_f32_batch {
5858

5959
struct clip_image_u8 * make_clip_image_u8();
6060
struct clip_image_f32 * make_clip_image_f32();
61+
LLAMA_API void clip_image_u8_free(clip_image_u8 * img);
62+
LLAMA_API void clip_image_f32_free(clip_image_f32 * img);
6163
LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
64+
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
6265
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img);
66+
6367
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
6468
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
6569

llava/llava.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111
#include "base64.hpp"
1212

1313
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) {
14-
clip_image_f32 img_res;
15-
if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) {
14+
clip_image_f32 * img_res = make_clip_image_f32();
15+
if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
1616
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
17-
17+
clip_image_f32_free(img_res);
1818
return false;
1919
}
2020

2121
*n_img_pos = clip_n_patches(ctx_clip);
2222
*n_image_embd = clip_n_mmproj_embd(ctx_clip);
2323

2424
const int64_t t_img_enc_start_us = ggml_time_us();
25-
if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) {
25+
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
26+
clip_image_f32_free(img_res);
27+
if (!encoded) {
2628
fprintf(stderr, "Unable to encode image\n");
2729

2830
return false;

0 commit comments

Comments
 (0)