Commit 8d4a855

Add embedding mode with arg flag. Currently working (#282)
* working but ugly
* add arg flag, not working on embedding mode
* typo
* Working! Thanks to @nullhook
* make params argument instead of hardcoded boolean. remove useless time check
* start doing the instructions but not finished. This probably doesn't compile
* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent b6b268d commit 8d4a855

5 files changed (+82, -10 lines)

llama.cpp

Lines changed: 46 additions & 10 deletions
@@ -102,6 +102,9 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
+
+    // input embedding (1-dimensional array: [n_embd])
+    std::vector<float> embedding;
 };

 struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv     =*/ false,
         /*.logits_all =*/ false,
         /*.vocab_only =*/ false,
+        /*.embedding  =*/ false,
     };

     return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
         fin.close();
     }

-    lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
     lctx.t_load_us = ggml_time_us() - t_start_us;

     return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
         inpL = cur;
     }

+    // used at the end to optionally extract the embeddings
+    struct ggml_tensor * embeddings = NULL;
+
     // norm
     {
         inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
         inpL = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.norm, inpL),
                     inpL);
+
+        embeddings = inpL;
     }

     // lm_head
@@ -821,15 +828,26 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

-    auto & logits_out = lctx.logits;
+    // extract logits
+    {
+        auto & logits_out = lctx.logits;
+
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        }
+    }
+
+    // extract embeddings
+    if (lctx.embedding.size()) {
+        auto & embedding_out = lctx.embedding;

-    if (lctx.logits_all) {
-        logits_out.resize(n_vocab * N);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-    } else {
-        // return result for just the last token
-        logits_out.resize(n_vocab);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }

     if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }

+    // reserve memory for context buffers
+    {
+        const auto & hparams = ctx->model.hparams;
+        if (params.logits_all) {
+            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+        } else {
+            ctx->logits.reserve(hparams.n_ctx);
+        }
+
+        if (params.embedding){
+            ctx->embedding.reserve(hparams.n_embd);
+        }
+    }
+
     return ctx;
 }

@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }

+float * llama_get_embeddings(struct llama_context * ctx) {
+    return ctx->embedding.data();
+}
+
 const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
     if (token >= llama_n_vocab(ctx)) {
         return nullptr;

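For reference, the extraction added above copies a single row out of the flattened [N x n_embd] activation left by the final RMS norm: with N evaluated tokens, the embedding is the slice starting at offset n_embd*(N - 1). A minimal standalone sketch of that slice logic (illustrative only, not part of the commit):

#include <cstring>
#include <vector>

// Copy the last token's row out of a flattened [N x n_embd] activation buffer,
// mirroring the memcpy over ggml_get_data(embeddings) in llama_eval_internal above.
static std::vector<float> last_token_row(const float * data, int n_embd, int N) {
    std::vector<float> out(n_embd);
    std::memcpy(out.data(), data + (size_t) n_embd * (N - 1), sizeof(float) * (size_t) n_embd);
    return out;
}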
llama.h

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,7 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
+        bool embedding;  // embedding mode only
     };

     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);

+    // Get the embeddings for the input
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);

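Taken together with llama_eval(), the new field and accessor give a small embedding workflow. Below is a hedged, standalone sketch of how a caller might use it; it assumes the functions already declared in llama.h at this point (llama_context_default_params, llama_init_from_file, llama_tokenize, llama_eval, llama_n_embd, llama_free), and the model path, prompt, and thread count are placeholders:

#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    // enable the embedding mode added by this commit
    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true;

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams); // placeholder path
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // tokenize the text to embed (add_bos = true)
    std::vector<llama_token> tokens(512);
    const int n_tokens = llama_tokenize(ctx, "Hello, world", tokens.data(), (int) tokens.size(), true);
    if (n_tokens < 1) {
        fprintf(stderr, "failed to tokenize\n");
        return 1;
    }
    tokens.resize(n_tokens);

    // one eval over the whole prompt; llama_get_embeddings() then exposes the
    // last token's final-norm activation (see llama_eval_internal above)
    if (llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, 4)) {
        fprintf(stderr, "failed to eval\n");
        return 1;
    }

    const int     n_embd = llama_n_embd(ctx);
    const float * embd   = llama_get_embeddings(ctx);

    for (int i = 0; i < n_embd; i++) {
        printf("%f ", embd[i]);
    }
    printf("\n");

    llama_free(ctx);
    return 0;
}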
main.cpp

Lines changed: 23 additions & 0 deletions
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
         lparams.seed       = params.seed;
         lparams.f16_kv     = params.memory_f16;
         lparams.logits_all = params.perplexity;
+        lparams.embedding  = params.embedding;

         ctx = llama_init_from_file(params.model.c_str(), lparams);

@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd;

+
     int last_n_size = params.repeat_last_n;
     std::vector<llama_token> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
     // the first thing we will do is to output the prompt, so set color accordingly
     set_console_state(CONSOLE_STATE_PROMPT);

+    if (params.embedding){
+        embd = embd_inp;
+
+        if (embd.size() > 0) {
+            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return 1;
+            }
+        }
+
+        const auto embeddings = llama_get_embeddings(ctx);
+
+        // TODO: print / use the embeddings
+
+        if (params.use_color) {
+            printf(ANSI_COLOR_RESET);
+        }
+
+        return 0;
+    }
+
     while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {

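The block above stops at the TODO, so nothing is printed or consumed yet. As one illustrative way to use the returned vector (not part of the commit), a small standalone helper that compares two such embeddings by cosine similarity:

#include <cmath>
#include <vector>

// cosine similarity between two embedding vectors of equal length, e.g. two
// copies of llama_get_embeddings() output taken after evaluating two prompts
static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    float dot = 0.0f, na = 0.0f, nb = 0.0f;
    for (size_t i = 0; i < a.size() && i < b.size(); i++) {
        dot += a[i]*b[i];
        na  += a[i]*a[i];
        nb  += b[i]*b[i];
    }
    return dot / (std::sqrt(na)*std::sqrt(nb) + 1e-6f);
}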
utils.cpp

Lines changed: 4 additions & 0 deletions
@@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
+        } else if (arg == "--interactive-start") {
+            params.interactive = true;
         } else if (arg == "--interactive-first") {
             params.interactive_start = true;
         } else if (arg == "-ins" || arg == "--instruct") {

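With the flag parsed here and forwarded to the context params in main.cpp, embedding mode can be selected from the command line, for example (assuming the pre-existing -m model and -p prompt flags; at this commit the extracted embeddings are not yet printed, so the program exits right after the eval):

./main -m models/7B/ggml-model.bin -p "some text to embed" --embedding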
utils.h

Lines changed: 4 additions & 0 deletions
@@ -32,13 +32,17 @@ struct gpt_params {
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";

+
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

     bool memory_f16 = false; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
     bool interactive = false; // interactive mode
+
+    bool embedding = false; // get only sentence embedding
     bool interactive_start = false; // wait for user input immediately
+
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool ignore_eos = false; // do not stop generating after eos
     bool perplexity = false; // compute perplexity over the prompt
