Skip to content

Commit 854ae5d

Browse files
committed
metal : temporary workaround for the concurrency optimization bug
1 parent 0a85ae7 commit 854ae5d

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

llama.cpp

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2333,9 +2333,11 @@ static struct ggml_cgraph * llm_build_llama(
23332333
inpL = cur;
23342334
}
23352335

2336+
cur = inpL;
2337+
23362338
// norm
23372339
{
2338-
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
2340+
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
23392341
offload_func_nr(cur);
23402342
ggml_set_name(cur, "rms_norm_2");
23412343

@@ -2436,7 +2438,6 @@ static struct ggml_cgraph * llm_build_falcon(
24362438
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
24372439

24382440
for (int il = 0; il < n_layer; ++il) {
2439-
struct ggml_tensor * cur;
24402441
struct ggml_tensor * attn_norm;
24412442

24422443
// self-attention
@@ -2561,6 +2562,12 @@ static struct ggml_cgraph * llm_build_falcon(
25612562
struct ggml_tensor * inpFF = attn_norm;
25622563

25632564
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
2565+
2566+
// TODO: this is temporarily needed to introduce an artificial dependency between FF and ATTN
2567+
// adding this, because there seems to be a bug in the Metal concurrency optimization
2568+
// without this line, the results are non-deterministic and wrong
2569+
cur->src[2] = attn_out;
2570+
25642571
cur = ggml_gelu(ctx0, cur);
25652572
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
25662573
}
@@ -2572,9 +2579,11 @@ static struct ggml_cgraph * llm_build_falcon(
25722579
inpL = cur;
25732580
}
25742581

2582+
cur = inpL;
2583+
25752584
// norm
25762585
{
2577-
cur = ggml_norm(ctx0, inpL, norm_eps);
2586+
cur = ggml_norm(ctx0, cur, norm_eps);
25782587

25792588
cur = ggml_add(ctx0,
25802589
ggml_mul(ctx0, cur, model.output_norm),

0 commit comments

Comments
 (0)