Skip to content

Commit 6916ed1

Browse files
committed
llama : aboud ggml_repeat during classification
1 parent 6235c62 commit 6916ed1

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

src/llama.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10243,9 +10243,6 @@ struct llm_build_context {
1024310243
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
1024410244
cur = ggml_tanh(ctx0, cur);
1024510245
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
10246-
10247-
// broadcast across the embedding size to make it compatible with the llama_get_embeddings API
10248-
cur = ggml_repeat(ctx0, cur, inp);
1024910246
} break;
1025010247
default:
1025110248
{
@@ -16997,7 +16994,6 @@ static int llama_decode_internal(
1699716994
case LLAMA_POOLING_TYPE_MEAN:
1699816995
case LLAMA_POOLING_TYPE_CLS:
1699916996
case LLAMA_POOLING_TYPE_LAST:
17000-
case LLAMA_POOLING_TYPE_RANK:
1700116997
{
1700216998
// extract sequence embeddings (cleared before processing each batch)
1700316999
auto & embd_seq_out = lctx.embd_seq;
@@ -17011,6 +17007,20 @@ static int llama_decode_internal(
1701117007
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1701217008
}
1701317009
} break;
17010+
case LLAMA_POOLING_TYPE_RANK:
17011+
{
17012+
// extract the rank score - a single float per sequence
17013+
auto & embd_seq_out = lctx.embd_seq;
17014+
17015+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
17016+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
17017+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
17018+
continue;
17019+
}
17020+
embd_seq_out[seq_id].resize(1);
17021+
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
17022+
}
17023+
} break;
1701417024
case LLAMA_POOLING_TYPE_UNSPECIFIED:
1701517025
{
1701617026
GGML_ABORT("unknown pooling type");

0 commit comments

Comments
 (0)