Skip to content

Commit afefa31

Browse files
authored
ggml : change ggml_scale to take a float instead of tensor (#4573)
* ggml : change ggml_scale to take a float instead of tensor * ggml : fix CPU implementation * tests : fix test-grad0 ggml-ci
1 parent 769a7bc commit afefa31

File tree

12 files changed

+81
-204
lines changed

12 files changed

+81
-204
lines changed

examples/baby-llama/baby-llama.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -575,10 +575,7 @@ static struct ggml_tensor * forward(
575575

576576
// KQ_scaled = KQ / sqrt(n_embd/n_head)
577577
// KQ_scaled shape [n_past + N, N, n_head, 1]
578-
struct ggml_tensor * KQ_scaled =
579-
ggml_scale(ctx0,
580-
KQ,
581-
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
578+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
582579

583580
// KQ_masked = mask_past(KQ_scaled)
584581
// KQ_masked shape [n_past + N, N, n_head, 1]
@@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(
844841

845842
// KQ_scaled = KQ / sqrt(n_embd/n_head)
846843
// KQ_scaled shape [n_past + N, N, n_head, n_batch]
847-
struct ggml_tensor * KQ_scaled =
848-
ggml_scale(ctx0,
849-
KQ,
850-
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
844+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
851845
assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
852846

853847
// KQ_masked = mask_past(KQ_scaled)
@@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(
11311125

11321126
// KQ_scaled = KQ / sqrt(n_embd/n_head)
11331127
// KQ_scaled shape [n_past + N, N, n_head, 1]
1134-
struct ggml_tensor * KQ_scaled =
1135-
ggml_scale(ctx0,
1136-
KQ,
1137-
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
1128+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
11381129

11391130
// KQ_masked = mask_past(KQ_scaled)
11401131
// KQ_masked shape [n_past + N, N, n_head, 1]

examples/export-lora/export-lora.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora(
309309
) {
310310
struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
311311
if (scaling != 1.0f) {
312-
ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
312+
ab = ggml_scale(ctx, ab, scaling);
313313
}
314314
struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);
315315

examples/finetune/finetune.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
269269
float rope_freq_scale = 1.0f;
270270
GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
271271
GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
272-
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
272+
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
273273
if (rope_freq_scale != 1.0f) {
274274
hparams->rope_freq_scale = 1.0f / rope_freq_scale;
275275
}
@@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
612612
const int n_rot = hparams.n_embd_head();
613613
const int n_embd_head = hparams.n_embd_head();
614614
const int n_embd_gqa = hparams.n_embd_gqa();
615+
615616
const float rms_norm_eps = hparams.f_norm_rms_eps;
616617
const float rope_freq_base = hparams.rope_freq_base;
617618
const float rope_freq_scale = hparams.rope_freq_scale;
@@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
680681
checkpoints.push_back(t01);
681682
}
682683

683-
struct ggml_tensor * kv_scale = NULL;
684-
if (!enable_flash_attn) {
685-
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
686-
}
684+
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
687685

688686
for (int il = 0; il < n_layer; ++il) {
689687
struct my_llama_layer & layer = model->layers[il];
@@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
781779
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
782780
int n_leafs_before = gb->n_leafs;
783781
int n_nodes_before = gb->n_nodes;
784-
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
782+
785783
// output tensors
786-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
787-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
784+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
785+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
788786
// input gradient
789-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
787+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
790788
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
791789
ggml_allocr_alloc(alloc, t36->grad);
792790
// KQ_pos
793-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
791+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
794792

795793
// make sure base model tensors data cannot be used in viewable operations
796-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
797-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
798-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
794+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
795+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
796+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
799797
for (int il = 0; il < n_layer; ++il) {
800798
struct my_llama_layer & layer = model->layers[il];
801-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
802-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
803-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
804-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
805-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
806-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
807-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
808-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
809-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
799+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
800+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
801+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
802+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
803+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
804+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
805+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
806+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
807+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
810808
}
811809

812810
// allocating checkpoints in one block to reduce memory fragmentation

examples/llava/clip.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
330330
ggml_repeat(ctx0, model.pre_ln_b, embeddings));
331331
}
332332

333-
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
334-
ggml_allocr_alloc(ctx->alloc, KQ_scale);
335-
if (!ggml_allocr_is_measure(ctx->alloc)) {
336-
ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
337-
}
338-
339333
// loop over layers
340334
for (int il = 0; il < n_layer - 1; il++) {
341335
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
356350
struct ggml_tensor * Q =
357351
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
358352

359-
Q = ggml_scale_inplace(ctx0, Q, KQ_scale);
353+
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
360354
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
361355
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
362356
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs(
369369
checkpoints.push_back(t00);
370370
checkpoints.push_back(t01);
371371

372-
struct ggml_tensor * kv_scale = NULL;
373-
if (!enable_flash_attn) {
374-
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
375-
}
372+
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
376373

377374
for (int il = 0; il < n_layer; ++il) {
378375
struct my_llama_layer & layer = model->layers[il];
@@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs(
444441
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
445442
int n_leafs_before = gb->n_leafs;
446443
int n_nodes_before = gb->n_nodes;
447-
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
448444
// output tensors
449-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
450-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
445+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
446+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
451447
// input gradient
452-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
448+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
453449
// KQ_pos
454-
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
450+
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
455451
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
456452

457453
ggml_allocr_alloc(alloc, t36->grad);

ggml-cuda.cu

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7700,17 +7700,9 @@ inline void ggml_cuda_op_scale(
77007700
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
77017701

77027702
GGML_ASSERT(src0->type == GGML_TYPE_F32);
7703-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
77047703
GGML_ASSERT( dst->type == GGML_TYPE_F32);
77057704

7706-
float scale;
7707-
// HACK: support for ggml backend interface
7708-
if (src1->backend == GGML_BACKEND_CPU) {
7709-
scale = ((float *) src1->data)[0];
7710-
} else {
7711-
// TODO: pass pointer to kernel instead of copying to host
7712-
CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
7713-
}
7705+
const float scale = ((float *) dst->op_params)[0];
77147706

77157707
scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
77167708
CUDA_CHECK(cudaGetLastError());
@@ -7757,8 +7749,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
77577749
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
77587750
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
77597751

7760-
const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
7761-
77627752
// dd = data device
77637753
float * src0_ddf = nullptr;
77647754
float * src1_ddf = nullptr;
@@ -7779,7 +7769,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
77797769
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
77807770
}
77817771

7782-
if (use_src1 && !src1_stays_on_host) {
7772+
if (use_src1) {
77837773
if (src1_on_device) {
77847774
src1_ddf = (float *) src1_extra->data_device[g_main_device];
77857775
} else {

ggml-metal.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute(
12931293
{
12941294
GGML_ASSERT(ggml_is_contiguous(src0));
12951295

1296-
const float scale = *(const float *) src1->data;
1296+
const float scale = *(const float *) dst->op_params;
12971297

12981298
int64_t n = ggml_nelements(dst);
12991299

@@ -1304,8 +1304,8 @@ void ggml_metal_graph_compute(
13041304
[encoder setComputePipelineState:ctx->pipeline_scale];
13051305
}
13061306

1307-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1308-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1307+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1308+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
13091309
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
13101310

13111311
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];

ggml.c

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4171,39 +4171,39 @@ struct ggml_tensor * ggml_out_prod(
41714171
static struct ggml_tensor * ggml_scale_impl(
41724172
struct ggml_context * ctx,
41734173
struct ggml_tensor * a,
4174-
struct ggml_tensor * b,
4174+
float s,
41754175
bool inplace) {
4176-
GGML_ASSERT(ggml_is_scalar(b));
41774176
GGML_ASSERT(ggml_is_padded_1d(a));
41784177

41794178
bool is_node = false;
41804179

4181-
if (a->grad || b->grad) {
4180+
if (a->grad) {
41824181
is_node = true;
41834182
}
41844183

41854184
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
41864185

4186+
ggml_set_op_params(result, &s, sizeof(s));
4187+
41874188
result->op = GGML_OP_SCALE;
41884189
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
41894190
result->src[0] = a;
4190-
result->src[1] = b;
41914191

41924192
return result;
41934193
}
41944194

41954195
struct ggml_tensor * ggml_scale(
41964196
struct ggml_context * ctx,
41974197
struct ggml_tensor * a,
4198-
struct ggml_tensor * b) {
4199-
return ggml_scale_impl(ctx, a, b, false);
4198+
float s) {
4199+
return ggml_scale_impl(ctx, a, s, false);
42004200
}
42014201

42024202
struct ggml_tensor * ggml_scale_inplace(
42034203
struct ggml_context * ctx,
42044204
struct ggml_tensor * a,
4205-
struct ggml_tensor * b) {
4206-
return ggml_scale_impl(ctx, a, b, true);
4205+
float s) {
4206+
return ggml_scale_impl(ctx, a, s, true);
42074207
}
42084208

42094209
// ggml_set
@@ -10325,19 +10325,17 @@ static void ggml_compute_forward_out_prod(
1032510325
static void ggml_compute_forward_scale_f32(
1032610326
const struct ggml_compute_params * params,
1032710327
const struct ggml_tensor * src0,
10328-
const struct ggml_tensor * src1,
1032910328
struct ggml_tensor * dst) {
1033010329
GGML_ASSERT(ggml_is_contiguous(src0));
1033110330
GGML_ASSERT(ggml_is_contiguous(dst));
1033210331
GGML_ASSERT(ggml_are_same_shape(src0, dst));
10333-
GGML_ASSERT(ggml_is_scalar(src1));
1033410332

1033510333
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1033610334
return;
1033710335
}
1033810336

1033910337
// scale factor
10340-
const float v = *(float *) src1->data;
10338+
const float v = *(float *) dst->op_params;
1034110339

1034210340
const int ith = params->ith;
1034310341
const int nth = params->nth;
@@ -10368,12 +10366,11 @@ static void ggml_compute_forward_scale_f32(
1036810366
static void ggml_compute_forward_scale(
1036910367
const struct ggml_compute_params * params,
1037010368
const struct ggml_tensor * src0,
10371-
const struct ggml_tensor * src1,
1037210369
struct ggml_tensor * dst) {
1037310370
switch (src0->type) {
1037410371
case GGML_TYPE_F32:
1037510372
{
10376-
ggml_compute_forward_scale_f32(params, src0, src1, dst);
10373+
ggml_compute_forward_scale_f32(params, src0, dst);
1037710374
} break;
1037810375
default:
1037910376
{
@@ -14383,7 +14380,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1438314380
} break;
1438414381
case GGML_OP_SCALE:
1438514382
{
14386-
ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
14383+
ggml_compute_forward_scale(params, tensor->src[0], tensor);
1438714384
} break;
1438814385
case GGML_OP_SET:
1438914386
{
@@ -14839,7 +14836,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg
1483914836

1484014837
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
1484114838
if (ggml_hash_contains(zero_table, a)) {
14842-
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
14839+
struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
1484314840
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
1484414841
} else {
1484514842
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
@@ -14975,7 +14972,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1497514972
src0->grad,
1497614973
ggml_scale(ctx,
1497714974
ggml_mul(ctx, src0, tensor->grad),
14978-
ggml_new_f32(ctx, 2.0f)),
14975+
2.0f),
1497914976
zero_table);
1498014977
}
1498114978
} break;
@@ -14989,7 +14986,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1498914986
ggml_div(ctx,
1499014987
tensor->grad,
1499114988
tensor),
14992-
ggml_new_f32(ctx, 0.5f)),
14989+
0.5f),
1499314990
zero_table);
1499414991
}
1499514992
} break;
@@ -15155,17 +15152,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1515515152
{
1515615153
// necessary for llama
1515715154
if (src0->grad) {
15155+
const float s = ((float *) tensor->op_params)[0];
15156+
1515815157
src0->grad =
1515915158
ggml_add_or_set(ctx,
1516015159
src0->grad,
15161-
ggml_scale_impl(ctx, tensor->grad, src1, false),
15162-
zero_table);
15163-
}
15164-
if (src1->grad) {
15165-
src1->grad =
15166-
ggml_add_or_set(ctx,
15167-
src1->grad,
15168-
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
15160+
ggml_scale_impl(ctx, tensor->grad, s, false),
1516915161
zero_table);
1517015162
}
1517115163
} break;

ggml.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,13 +1094,13 @@ extern "C" {
10941094
GGML_API struct ggml_tensor * ggml_scale(
10951095
struct ggml_context * ctx,
10961096
struct ggml_tensor * a,
1097-
struct ggml_tensor * b);
1097+
float s);
10981098

10991099
// in-place, returns view(a)
11001100
GGML_API struct ggml_tensor * ggml_scale_inplace(
11011101
struct ggml_context * ctx,
11021102
struct ggml_tensor * a,
1103-
struct ggml_tensor * b);
1103+
float s);
11041104

11051105
// b -> view(a,offset,nb1,nb2,3), return modified a
11061106
GGML_API struct ggml_tensor * ggml_set(

0 commit comments

Comments
 (0)