Skip to content

Commit 51755b7

Browse files
author
Judd
committed
Cohere: use logit_scale just as in CohereForCausalLM.
1 parent 7f43063 commit 51755b7

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

models.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ namespace chatllm
269269
BaseModelForConditionalGeneration(ModelType model_type, BaseConfig config, size_t mem_size, size_t scratch_size)
270270
: BaseModel(model_type, to_string(model_type), to_native_string(model_type), get_model_purpose(model_type)),
271271
GRAPH_SIZE(GGML_DEFAULT_GRAPH_SIZE),
272-
batch_input(true),
272+
batch_input(true), logit_scale(-1.0f),
273273
config_(config), mem_size_(mem_size), mem_buffer_(new char[mem_size]),
274274
scratch_size_(scratch_size), scratch_buffer_(new char[scratch_size])
275275
{
@@ -485,6 +485,9 @@ namespace chatllm
485485

486486
ggml_tensor *r = transformer.forward(&ctx, input_ids_tensor, past);
487487

488+
if (logit_scale > 0)
489+
r = ggml_scale_inplace(ctx.gctx.get(), r, logit_scale);
490+
488491
ggml_build_forward_expand(ctx.gf, r);
489492
ggml_graph_compute_with_ctx(ctx.gctx.get(), ctx.gf, n_threads);
490493

@@ -557,6 +560,7 @@ namespace chatllm
557560
LM transformer;
558561
size_t GRAPH_SIZE;
559562
bool batch_input;
563+
float logit_scale;
560564
private:
561565
BaseConfig config_;
562566
size_t mem_size_;

models/cohere.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration<
9393
attention.freq_base = config.rope_theta;
9494
}
9595

96+
logit_scale = config.logit_scale;
97+
9698
GRAPH_SIZE = 4096;
9799
}
98100

0 commit comments

Comments (0)