Bug: Unexpected output from Granite 3.0 MoE 1b when all layers on NVIDIA GPU #9991
When following these steps, the conversion fails due to a missing |
Thanks for the quick check! You should have a |
I see (`tree granite-3.0-1b-a400m-instruct/` output):

```
granite-3.0-1b-a400m-instruct/
├── README.md
├── added_tokens.json
├── config.json
├── generation_config.json
├── ggml-model-Q4_K_M.gguf
├── granite-3.0-1B-a400M-instruct-F16.gguf
├── granite-3.0-1B-a400M-instruct-F32.gguf
├── merges.txt
├── model.safetensors
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.json
```
My bad, I was using an outdated version of hf-transformers; updating it fixed it. I was able to convert the model and reproduce the issue now.
Looks like this model requires increased precision in the attention. Using the following patch:

```diff
diff --git a/src/llama.cpp b/src/llama.cpp
index 98ec123c..66d52fe2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9580,20 +9580,16 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-        }
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-        }
+        // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+        // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
```
Great! I'm off for the day, but will dig more tomorrow. I appreciate any insight you uncover.
It's a fairly common problem: the CUDA backend does matrix multiplications with F16 precision by default. Generally this is OK, but with some models this causes issues in the KQ matrix multiplication. Either enabling flash attention with |
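For reference, a minimal sketch of the ggml-level mechanism being described here (illustration only, not the actual llama.cpp code; the helper name is made up): requesting F32 precision for the KQ matrix multiplication so the CUDA backend does not accumulate the scores in F16.

```cpp
// Minimal sketch (not llama.cpp's actual code): force the KQ matmul to use F32 precision.
#include "ggml.h"

static struct ggml_tensor * build_kq_scores_f32(
        struct ggml_context * ctx,
        struct ggml_tensor  * k,   // key tensor, laid out as the caller expects
        struct ggml_tensor  * q) { // query tensor
    // attention scores, following the kq = ggml_mul_mat(ctx, k, q) pattern from llm_build_kqv
    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

    // without this, the CUDA backend may run the GEMM with F16 accumulation by default,
    // which produces garbage/NaNs for some models (such as this Granite MoE)
    ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

    return kq;
}
```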
If we change the default for KV precision we should maybe also change the default for FlashAttention. The llama.cpp FlashAttention implementation should have largely the same performance regardless of whether FP16 or FP32 precision is used, since all that changes is the KQ accumulator type and the precision for softmax. But for cuBLAS GEMM it is, I think, not possible to use different data types for inputs and outputs, even though the tensor core hardware would in principle support that. So for long contexts I would expect the performance without FlashAttention to degrade. |
Yes, it's time to change the default precision for this operation as the issues keep piling up. Regarding FA: since it is not generally supported by all backends, such as SYCL and Vulkan, and the performance on CPU is likely worse compared to no-FA because the FA CPU implementation is mainly used as a reference, it probably has to remain opt-in for now? Although it does seem to me that developers using |
Would it make sense to check |
Yes, we should enable flash attention by default, and automatically disable it for the layers where the backend does not support it using the |
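The comment above is cut off, but the mechanism being referred to is presumably a per-backend capability check. As a rough sketch (an assumption about the intended approach, not the actual llama.cpp implementation; the helper name is made up), ggml's `ggml_backend_supports_op()` could be used to probe whether a backend can run a FlashAttention node before enabling it:

```cpp
// Rough sketch (assumption, not the actual implementation): probe whether a backend can
// execute a FlashAttention node with the shapes the model will use, so FA could default
// to on and fall back where it is unsupported.
#include "ggml.h"
#include "ggml-backend.h"

static bool backend_supports_flash_attn(
        ggml_backend_t        backend,
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        float                 kq_scale) {
    // build a throw-away node only to describe the op and its tensor shapes
    struct ggml_tensor * fa = ggml_flash_attn_ext(ctx, q, k, v, kq_mask,
                                                  kq_scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);

    // ask the backend whether it is able to execute this node
    return ggml_backend_supports_op(backend, fa);
}
```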
What happened?
Overview
As we launched the Granite 3.0 models today, we have found that one of them, the `1b-a400m` MoE model, behaves very strangely if all layers are placed on an NVIDIA GPU. If we keep the last two layers off the GPU, the model performs as expected. I'm opening this ticket to track my own work investigating the issue, as well as to see if it happens to trigger any thoughts for others in the community.
Details
We originally noticed this via the `ollama` integration, but I've been able to repro it with `llama-cli` directly, so I'm trying to isolate the issue here. I've done the following investigation so far:

- I've tried the `F32`, `F16`, and `Q4_K_M` variants with no change, so this doesn't appear to be related to dtype or quantization.
- I've also tried the larger model in the family (`3b-a800m`) and see none of the same behavior with this model. It has the same architecture, but subtly different parameters.
- Running with `--split-mode none` does not change the results.

Experiments
When run with `llama-cli -m <model> -ngl 23 -p "hi" -n 10`, the output is sane:
``` llama-cli -m granite-3.0-1b-a400m-instruct/granite-3.0-1B-a400M-instruct-F32.gguf -ngl 23 -p "hi" -n 10 ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 8 CUDA devices: Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 1: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 2: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 3: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 4: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 5: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 6: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes Device 7: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes build: 3953 (994cfb1) with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu main: llama backend init main: load the model and apply lora adapter, if any llama_load_model_from_file: using device CUDA0 (NVIDIA A100-SXM4-80GB) - 80732 MiB free llama_load_model_from_file: using device CUDA1 (NVIDIA A100-SXM4-80GB) - 80732 MiB free llama_load_model_from_file: using device CUDA2 (NVIDIA A100-SXM4-80GB) - 80732 MiB free llama_load_model_from_file: using device CUDA3 (NVIDIA A100-SXM4-80GB) - 80732 MiB free llama_load_model_from_file: using device CUDA4 (NVIDIA A100-SXM4-80GB) - 67189 MiB free llama_load_model_from_file: using device CUDA5 (NVIDIA A100-SXM4-80GB) - 68997 MiB free llama_load_model_from_file: using device CUDA6 (NVIDIA A100-SXM4-80GB) - 29537 MiB free llama_load_model_from_file: using device CUDA7 (NVIDIA A100-SXM4-80GB) - 70903 MiB free llama_model_loader: loaded meta data with 38 key-value pairs and 242 tensors from granite-3.0-1b-a400m-instruct/granite-3.0-1B-a400M-instruct-F32.gguf (version GGUF V3 (latest)) llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output. llama_model_loader: - kv 0: general.architecture str = granitemoe llama_model_loader: - kv 1: general.type str = model llama_model_loader: - kv 2: general.name str = Granite 3.0 1b A400M Instruct llama_model_loader: - kv 3: general.finetune str = instruct llama_model_loader: - kv 4: general.basename str = granite-3.0 llama_model_loader: - kv 5: general.size_label str = 1B-a400M llama_model_loader: - kv 6: general.license str = apache-2.0 llama_model_loader: - kv 7: general.tags arr[str,3] = ["language", "granite-3.0", "text-gen... 
llama_model_loader: - kv 8: granitemoe.block_count u32 = 24 llama_model_loader: - kv 9: granitemoe.context_length u32 = 4096 llama_model_loader: - kv 10: granitemoe.embedding_length u32 = 1024 llama_model_loader: - kv 11: granitemoe.feed_forward_length u32 = 512 llama_model_loader: - kv 12: granitemoe.attention.head_count u32 = 16 llama_model_loader: - kv 13: granitemoe.attention.head_count_kv u32 = 8 llama_model_loader: - kv 14: granitemoe.rope.freq_base f32 = 10000.000000 llama_model_loader: - kv 15: granitemoe.attention.layer_norm_rms_epsilon f32 = 0.000001 llama_model_loader: - kv 16: granitemoe.expert_count u32 = 32 llama_model_loader: - kv 17: granitemoe.expert_used_count u32 = 8 llama_model_loader: - kv 18: general.file_type u32 = 0 llama_model_loader: - kv 19: granitemoe.vocab_size u32 = 49155 llama_model_loader: - kv 20: granitemoe.rope.dimension_count u32 = 64 llama_model_loader: - kv 21: tokenizer.ggml.add_space_prefix bool = false llama_model_loader: - kv 22: granitemoe.attention.scale f32 = 0.015625 llama_model_loader: - kv 23: granitemoe.embedding_scale f32 = 12.000000 llama_model_loader: - kv 24: granitemoe.residual_scale f32 = 0.220000 llama_model_loader: - kv 25: granitemoe.logit_scale f32 = 6.000000 llama_model_loader: - kv 26: tokenizer.ggml.model str = gpt2 llama_model_loader: - kv 27: tokenizer.ggml.pre str = refact llama_model_loader: - kv 28: tokenizer.ggml.tokens arr[str,49155] = ["<|end_of_text|>", "", "... llama_model_loader: - kv 29: tokenizer.ggml.token_type arr[i32,49155] = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ... llama_model_loader: - kv 30: tokenizer.ggml.merges arr[str,48891] = ["Ġ Ġ", "ĠĠ ĠĠ", "ĠĠĠĠ ĠĠ... llama_model_loader: - kv 31: tokenizer.ggml.bos_token_id u32 = 0 llama_model_loader: - kv 32: tokenizer.ggml.eos_token_id u32 = 0 llama_model_loader: - kv 33: tokenizer.ggml.unknown_token_id u32 = 0 llama_model_loader: - kv 34: tokenizer.ggml.padding_token_id u32 = 0 llama_model_loader: - kv 35: tokenizer.ggml.add_bos_token bool = false llama_model_loader: - kv 36: tokenizer.chat_template str = {%- if tools %}\n {{- '<|start_of_r... 
llama_model_loader: - kv 37: general.quantization_version u32 = 2 llama_model_loader: - type f32: 242 tensors llm_load_vocab: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect llm_load_vocab: special tokens cache size = 22 llm_load_vocab: token to piece cache size = 0.2826 MB llm_load_print_meta: format = GGUF V3 (latest) llm_load_print_meta: arch = granitemoe llm_load_print_meta: vocab type = BPE llm_load_print_meta: n_vocab = 49155 llm_load_print_meta: n_merges = 48891 llm_load_print_meta: vocab_only = 0 llm_load_print_meta: n_ctx_train = 4096 llm_load_print_meta: n_embd = 1024 llm_load_print_meta: n_layer = 24 llm_load_print_meta: n_head = 16 llm_load_print_meta: n_head_kv = 8 llm_load_print_meta: n_rot = 64 llm_load_print_meta: n_swa = 0 llm_load_print_meta: n_embd_head_k = 64 llm_load_print_meta: n_embd_head_v = 64 llm_load_print_meta: n_gqa = 2 llm_load_print_meta: n_embd_k_gqa = 512 llm_load_print_meta: n_embd_v_gqa = 512 llm_load_print_meta: f_norm_eps = 0.0e+00 llm_load_print_meta: f_norm_rms_eps = 1.0e-06 llm_load_print_meta: f_clamp_kqv = 0.0e+00 llm_load_print_meta: f_max_alibi_bias = 0.0e+00 llm_load_print_meta: f_logit_scale = 6.0e+00 llm_load_print_meta: n_ff = 512 llm_load_print_meta: n_expert = 32 llm_load_print_meta: n_expert_used = 8 llm_load_print_meta: causal attn = 1 llm_load_print_meta: pooling type = 0 llm_load_print_meta: rope type = 0 llm_load_print_meta: rope scaling = linear llm_load_print_meta: freq_base_train = 10000.0 llm_load_print_meta: freq_scale_train = 1 llm_load_print_meta: n_ctx_orig_yarn = 4096 llm_load_print_meta: rope_finetuned = unknown llm_load_print_meta: ssm_d_conv = 0 llm_load_print_meta: ssm_d_inner = 0 llm_load_print_meta: ssm_d_state = 0 llm_load_print_meta: ssm_dt_rank = 0 llm_load_print_meta: ssm_dt_b_c_rms = 0 llm_load_print_meta: model type = ?B llm_load_print_meta: model ftype = all F32 llm_load_print_meta: model params = 1.33 B llm_load_print_meta: model size = 4.97 GiB (32.00 BPW) llm_load_print_meta: general.name = Granite 3.0 1b A400M Instruct llm_load_print_meta: BOS token = 0 '<|end_of_text|>' llm_load_print_meta: EOS token = 0 '<|end_of_text|>' llm_load_print_meta: UNK token = 0 '<|end_of_text|>' llm_load_print_meta: PAD token = 0 '<|end_of_text|>' llm_load_print_meta: LF token = 145 'Ä' llm_load_print_meta: EOG token = 0 '<|end_of_text|>' llm_load_print_meta: max token length = 512 llm_load_print_meta: f_embedding_scale = 12.000000 llm_load_print_meta: f_residual_scale = 0.220000 llm_load_print_meta: f_attention_scale = 0.015625 llm_load_tensors: ggml ctx size = 0.99 MiB llm_load_tensors: offloading 23 repeating layers to GPU llm_load_tensors: offloaded 23/25 layers to GPU llm_load_tensors: CPU buffer size = 5091.20 MiB llm_load_tensors: CUDA0 buffer size = 816.53 MiB llm_load_tensors: CUDA1 buffer size = 612.40 MiB llm_load_tensors: CUDA2 buffer size = 612.40 MiB llm_load_tensors: CUDA3 buffer size = 816.53 MiB llm_load_tensors: CUDA4 buffer size = 612.40 MiB llm_load_tensors: CUDA5 buffer size = 408.27 MiB llm_load_tensors: CUDA6 buffer size = 408.27 MiB llm_load_tensors: CUDA7 buffer size = 408.27 MiB ................................................................................ 
llama_new_context_with_model: n_ctx = 4096 llama_new_context_with_model: n_batch = 2048 llama_new_context_with_model: n_ubatch = 512 llama_new_context_with_model: flash_attn = 0 llama_new_context_with_model: freq_base = 10000.0 llama_new_context_with_model: freq_scale = 1 llama_kv_cache_init: CUDA_Host KV buffer size = 8.00 MiB llama_kv_cache_init: CUDA0 KV buffer size = 32.00 MiB llama_kv_cache_init: CUDA1 KV buffer size = 24.00 MiB llama_kv_cache_init: CUDA2 KV buffer size = 24.00 MiB llama_kv_cache_init: CUDA3 KV buffer size = 32.00 MiB llama_kv_cache_init: CUDA4 KV buffer size = 24.00 MiB llama_kv_cache_init: CUDA5 KV buffer size = 16.00 MiB llama_kv_cache_init: CUDA6 KV buffer size = 16.00 MiB llama_kv_cache_init: CUDA7 KV buffer size = 16.00 MiB llama_new_context_with_model: KV self size = 192.00 MiB, K (f16): 96.00 MiB, V (f16): 96.00 MiB llama_new_context_with_model: CUDA_Host output buffer size = 0.19 MiB llama_new_context_with_model: CUDA0 compute buffer size = 290.02 MiB llama_new_context_with_model: CUDA1 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA2 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA3 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA4 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA5 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA6 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA7 compute buffer size = 144.00 MiB llama_new_context_with_model: CUDA_Host compute buffer size = 10.01 MiB llama_new_context_with_model: graph nodes = 1472 llama_new_context_with_model: graph splits = 23 common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable) main: llama threadpool init, n_threads = 40system_info: n_threads = 40 (n_threads_batch = 40) / 80 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 1 | AVX512_BF16 = 0 | AMX_INT8 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampler seed: 3838250562
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> top-k -> tail-free -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 10, n_keep = 0
hi NAME: Hello! How can I assist you today
When run with `llama-cli -m <model> -ngl 24 -p "hi" -n 10`, the output is a small number of sane tokens, followed by a sequence of `4`s.
With 25/25 layers, the results are similar to the 24-layer case, but not identical.
When I disable KV offload (`--no-kv-offload`), the results are sane, but also produce GGML backend error messages.
Setup Repro
```
huggingface-cli download ibm-granite/granite-3.0-1b-a400m-instruct --local-dir granite-3.0-1b-a400m-instruct

# Convert to F32, F16, and Q4_K_M
convert_hf_to_gguf.py granite-3.0-1b-a400m-instruct/ --outtype f32
convert_hf_to_gguf.py granite-3.0-1b-a400m-instruct/
llama-quantize granite-3.0-1b-a400m-instruct/granite-3.0-1B-a400M-instruct-F16.gguf Q4_K_M
```
Name and Version
What operating system are you seeing the problem on?
Relevant log output
(see above in details)