@@ -1812,6 +1812,12 @@ static bool llama_eval_internal(
18121812 // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
18131813 n_threads = N >= 32 && ggml_cpu_has_blas () && !ggml_cpu_has_gpublas () ? 1 : n_threads;
18141814
1815+ struct ggml_tensor * res = gf->nodes [gf->n_nodes - 1 ];
1816+ struct ggml_tensor * embeddings = gf->nodes [gf->n_nodes - 2 ];
1817+
1818+ LLAMA_ASSERT (strcmp (res->name , " result_output" ) == 0 );
1819+ LLAMA_ASSERT (strcmp (embeddings->name , " result_norm" ) == 0 );
1820+
18151821#if GGML_USE_MPI
18161822 const int64_t n_layer = hparams.n_layer ;
18171823 ggml_mpi_graph_compute_pre (lctx.ctx_mpi , gf, n_layer);
@@ -1825,7 +1831,10 @@ static bool llama_eval_internal(
18251831 // }
18261832 ggml_metal_set_n_cb (lctx.ctx_metal , n_threads);
18271833 ggml_metal_graph_compute (lctx.ctx_metal , gf);
1828- ggml_metal_get_tensor (lctx.ctx_metal , cur);
1834+ ggml_metal_get_tensor (lctx.ctx_metal , res);
1835+ if (!lctx.embedding .empty ()) {
1836+ ggml_metal_get_tensor (lctx.ctx_metal , embeddings);
1837+ }
18291838 } else {
18301839 // IMPORTANT:
18311840 // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
@@ -1856,12 +1865,6 @@ static bool llama_eval_internal(
18561865 // update kv token count
18571866 lctx.kv_self .n = n_past + N;
18581867
1859- struct ggml_tensor * res = gf->nodes [gf->n_nodes - 1 ];
1860- struct ggml_tensor * embeddings = gf->nodes [gf->n_nodes - 2 ];
1861-
1862- LLAMA_ASSERT (strcmp (res->name , " result_output" ) == 0 );
1863- LLAMA_ASSERT (strcmp (embeddings->name , " result_norm" ) == 0 );
1864-
18651868 if (cgraph_fname) {
18661869 ggml_graph_export (gf, cgraph_fname);
18671870 }
0 commit comments