#endif
#endif

+// TODO: Fix unused logit skipping crashes on ROCm
+// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
+#ifndef LLAMA_USE_HIPBLAS
+#define LLAMA_SKIP_UNUSED_LOGITS
+#endif
+
#include <array>
#include <ctime>
#include <cinttypes>
@@ -1594,6 +1600,7 @@ static struct ggml_cgraph * llama_build_graph(
            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
        }

+#ifdef LLAMA_SKIP_UNUSED_LOGITS
        if (il == n_layer - 1 && !lctx.logits_all)
        {
            // From here on, we only care about the last token and its logits.
@@ -1614,6 +1621,7 @@ static struct ggml_cgraph * llama_build_graph(
            n_past += N - 1;
            N = 1;
        }
+#endif // LLAMA_SKIP_UNUSED_LOGITS

        struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
        offload_func_kq(tmpq);
@@ -1920,9 +1928,14 @@ static bool llama_eval_internal(
            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
        } else {
            // return result for just the last token
-           GGML_ASSERT(ggml_nelements(res) == n_vocab);
            logits_out.resize(n_vocab);
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
+           GGML_ASSERT(ggml_nelements(res) == n_vocab);
            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
+#else
+           GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
+           memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+#endif
        }
    }
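
For context on what the `LLAMA_SKIP_UNUSED_LOGITS` path saves: when `logits_all` is false, only the last of the `N` batch positions ever needs logits, so the graph trimming above (`n_past += N - 1; N = 1;`) lets the final layer and the output projection run on a single token. Below is a minimal plain-C++ sketch of that idea, not llama.cpp/ggml API; the buffer layout and the function name are assumptions for illustration only.

```cpp
// Illustrative sketch only (plain C++, not llama.cpp/ggml API).
// Assumption: hidden states are row-major [N x n_embd] and the output
// projection (lm_head) is row-major [n_vocab x n_embd].
#include <cstddef>
#include <vector>

std::vector<float> project_logits(const std::vector<float> & hidden,   // N * n_embd
                                  const std::vector<float> & lm_head,  // n_vocab * n_embd
                                  std::size_t N, std::size_t n_embd, std::size_t n_vocab,
                                  bool skip_unused) {
    // With skipping enabled, only the last token's hidden state is projected,
    // mirroring the "n_past += N - 1; N = 1;" trimming in llama_build_graph.
    const std::size_t first = skip_unused ? N - 1 : 0;
    std::vector<float> logits((N - first) * n_vocab, 0.0f);
    for (std::size_t t = first; t < N; ++t) {
        for (std::size_t v = 0; v < n_vocab; ++v) {
            float acc = 0.0f;
            for (std::size_t e = 0; e < n_embd; ++e) {
                acc += hidden[t * n_embd + e] * lm_head[v * n_embd + e];
            }
            logits[(t - first) * n_vocab + v] = acc;
        }
    }
    return logits; // n_vocab floats when skipping, N * n_vocab otherwise
}
```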
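The `llama_eval_internal` hunk then has to copy the right slice out of `res`: with skipping, `res` holds exactly `n_vocab` floats; without it (the ROCm fallback), it still holds all `N*n_vocab` logits and the last row starts at offset `n_vocab*(N-1)`, which is what the differing `GGML_ASSERT`s check. A hedged stand-alone sketch of that copy, again with illustrative names rather than the actual llama.cpp code:

```cpp
// Sketch of the last-token copy in the else-branch above (illustrative names).
// Assumption: without skipping, res_data points at row-major [N x n_vocab]
// logits; with skipping it holds just the final row of n_vocab floats.
#include <cstddef>
#include <cstring>
#include <vector>

void copy_last_token_logits(std::vector<float> & logits_out,
                            const float * res_data,
                            std::size_t N, std::size_t n_vocab,
                            bool skipped_unused) {
    logits_out.resize(n_vocab);
    // When unused logits were skipped, res_data already starts at the last
    // (and only) row; otherwise step over the first N-1 rows.
    const float * last_row = skipped_unused ? res_data
                                            : res_data + n_vocab * (N - 1);
    std::memcpy(logits_out.data(), last_row, sizeof(float) * n_vocab);
}
```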