 #endif
 #endif
 
+// TODO: Fix unused logit skipping crashes on ROCm
+// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
+#ifndef LLAMA_USE_HIPBLAS
+#define LLAMA_SKIP_UNUSED_LOGITS
+#endif
+
 #include <array>
 #include <ctime>
 #include <cinttypes>
@@ -1594,6 +1600,7 @@ static struct ggml_cgraph * llama_build_graph(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
     }
 
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
     if (il == n_layer - 1 && !lctx.logits_all)
     {
         // From here on, we only care about the last token and its logits.
@@ -1614,6 +1621,7 @@ static struct ggml_cgraph * llama_build_graph(
         n_past += N - 1;
         N = 1;
     }
+#endif // LLAMA_SKIP_UNUSED_LOGITS
 
     struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
     offload_func_kq(tmpq);
@@ -1920,9 +1928,14 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
     } else {
         // return result for just the last token
-        GGML_ASSERT(ggml_nelements(res) == n_vocab);
         logits_out.resize(n_vocab);
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
+        GGML_ASSERT(ggml_nelements(res) == n_vocab);
         memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
+#else
+        GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
+        memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+#endif
     }
 }
 