
musa: enable MMA #13149


Draft: wants to merge 1 commit into master

Conversation

@yeahdongcn (Contributor) commented on Apr 28, 2025


This PR enables muBLAS and MMA support on MUSA (QY2).
Edit: I will rebase once #13144 is merged into master. (Done.)

Testing Done

root@2f2c227f02d8:/ws# ./build/bin/llama-cli -m /models/models/Ling-lite/Ling-lite.Q4_K_M.gguf -ngl 999 -fa
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 MUSA devices:
  Device 0: MTT S4000, compute capability 2.2, VMM: yes
build: 5202 (fd00e2d2) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device MUSA0 (MTT S4000) - 43580 MiB free
llama_model_loader: loaded meta data with 44 key-value pairs and 367 tensors from /models/models/Ling-lite/Ling-lite.Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = bailingmoe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Ling Lite
llama_model_loader: - kv   3:                         general.size_label str              = 64x1.5B
llama_model_loader: - kv   4:                            general.license str              = mit
llama_model_loader: - kv   5:                               general.tags arr[str,1]       = ["text-generation"]
llama_model_loader: - kv   6:                     bailingmoe.block_count u32              = 28
llama_model_loader: - kv   7:                  bailingmoe.context_length u32              = 16384
llama_model_loader: - kv   8:                bailingmoe.embedding_length u32              = 2048
llama_model_loader: - kv   9:             bailingmoe.feed_forward_length u32              = 5632
llama_model_loader: - kv  10:            bailingmoe.attention.head_count u32              = 16
llama_model_loader: - kv  11:         bailingmoe.attention.head_count_kv u32              = 4
llama_model_loader: - kv  12:                  bailingmoe.rope.freq_base f32              = 600000.000000
llama_model_loader: - kv  13: bailingmoe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  14:               bailingmoe.expert_used_count u32              = 6
llama_model_loader: - kv  15:            bailingmoe.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:               bailingmoe.rope.scaling.type str              = none
llama_model_loader: - kv  17:       bailingmoe.leading_dense_block_count u32              = 0
llama_model_loader: - kv  18:                      bailingmoe.vocab_size u32              = 126464
llama_model_loader: - kv  19:      bailingmoe.expert_feed_forward_length u32              = 1408
llama_model_loader: - kv  20:            bailingmoe.expert_weights_scale f32              = 1.000000
llama_model_loader: - kv  21:                    bailingmoe.expert_count u32              = 64
llama_model_loader: - kv  22:             bailingmoe.expert_shared_count u32              = 2
llama_model_loader: - kv  23:             bailingmoe.expert_weights_norm bool             = true
llama_model_loader: - kv  24:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  25:                         tokenizer.ggml.pre str              = bailingmoe
llama_model_loader: - kv  26:                      tokenizer.ggml.tokens arr[str,126464]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  27:                  tokenizer.ggml.token_type arr[i32,126464]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  28:                      tokenizer.ggml.merges arr[str,125824]  = ["Ġ Ġ", "Ġ t", "i n", "Ġ a", "h e...
llama_model_loader: - kv  29:                tokenizer.ggml.bos_token_id u32              = 126080
llama_model_loader: - kv  30:                tokenizer.ggml.eos_token_id u32              = 126081
llama_model_loader: - kv  31:            tokenizer.ggml.padding_token_id u32              = 126081
llama_model_loader: - kv  32:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  33:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  34:                    tokenizer.chat_template str              = {% for message in messages %}{% set r...
llama_model_loader: - kv  35:               general.quantization_version u32              = 2
llama_model_loader: - kv  36:                          general.file_type u32              = 15
llama_model_loader: - kv  37:                                general.url str              = https://huggingface.co/mradermacher/L...
llama_model_loader: - kv  38:              mradermacher.quantize_version str              = 2
llama_model_loader: - kv  39:                  mradermacher.quantized_by str              = mradermacher
llama_model_loader: - kv  40:                  mradermacher.quantized_at str              = 2025-03-31T05:37:59+02:00
llama_model_loader: - kv  41:                  mradermacher.quantized_on str              = kaos
llama_model_loader: - kv  42:                         general.source.url str              = https://huggingface.co/inclusionAI/Li...
llama_model_loader: - kv  43:                  mradermacher.convert_type str              = hf
llama_model_loader: - type  f32:   85 tensors
llama_model_loader: - type q5_0:   14 tensors
llama_model_loader: - type q8_0:   14 tensors
llama_model_loader: - type q4_K:  225 tensors
llama_model_loader: - type q6_K:   29 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 10.40 GiB (5.32 BPW) 
load: special tokens cache size = 266
load: token to piece cache size = 0.8056 MB
print_info: arch             = bailingmoe
print_info: vocab_only       = 0
print_info: n_ctx_train      = 16384
print_info: n_embd           = 2048
print_info: n_layer          = 28
print_info: n_head           = 16
print_info: n_head_kv        = 4
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: n_swa_pattern    = 1
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 4
print_info: n_embd_k_gqa     = 512
print_info: n_embd_v_gqa     = 512
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 5632
print_info: n_expert         = 64
print_info: n_expert_used    = 6
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = none
print_info: freq_base_train  = 600000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 16384
print_info: rope_finetuned   = unknown
print_info: ssm_d_conv       = 0
print_info: ssm_d_inner      = 0
print_info: ssm_d_state      = 0
print_info: ssm_dt_rank      = 0
print_info: ssm_dt_b_c_rms   = 0
print_info: model type       = 16B
print_info: model params     = 16.80 B
print_info: general.name     = Ling Lite
print_info: n_layer_dense_lead   = 0
print_info: n_ff_exp             = 1408
print_info: n_expert_shared      = 2
print_info: expert_weights_scale = 1.0
print_info: expert_weights_norm  = 1
print_info: vocab type       = BPE
print_info: n_vocab          = 126464
print_info: n_merges         = 125824
print_info: BOS token        = 126080 '<|startoftext|>'
print_info: EOS token        = 126081 '<|endoftext|>'
print_info: EOT token        = 126081 '<|endoftext|>'
print_info: PAD token        = 126081 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: EOG token        = 126081 '<|endoftext|>'
print_info: max token length = 154
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 28 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 29/29 layers to GPU
load_tensors:        MUSA0 model buffer size = 10513.91 MiB
load_tensors:   CPU_Mapped model buffer size =   138.94 MiB
......................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: freq_base     = 600000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (16384) -- the full capacity of the model will not be utilized
llama_context:  MUSA_Host  output buffer size =     0.48 MiB
init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 28, can_shift = 1
init:      MUSA0 KV buffer size =   224.00 MiB
llama_context: KV self size  =  224.00 MiB, K (f16):  112.00 MiB, V (f16):  112.00 MiB
llama_context:      MUSA0 compute buffer size =   259.00 MiB
llama_context:  MUSA_Host compute buffer size =    20.01 MiB
llama_context: graph nodes  = 1659
llama_context: graph splits = 58
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 6
main: chat template is available, enabling conversation mode (disable it with -no-cnv)
main: chat template example:
<role>SYSTEM</role>You are a helpful assistant<role>HUMAN</role>Hello<role>ASSISTANT</role>Hi there<role>HUMAN</role>How are you?<role>ASSISTANT</role>

system_info: n_threads = 6 (n_threads_batch = 6) / 12 | MUSA : PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | 

main: interactive mode on.
sampler seed: 1499548142
sampler params: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 0

== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.
 - Not using system message. To change it, set a different value via -sys PROMPT


> Hi
Hello! How can I assist you today?

> Do you know math?
Yes, I have knowledge of various mathematical concepts and can help with solving problems, explaining formulas, and providing information on different mathematical topics. If you have a specific math question or need help with something, feel free to ask!

> 
llama_perf_sampler_print:    sampling time =       3.30 ms /    60 runs   (    0.05 ms per token, 18192.84 tokens per second)
llama_perf_context_print:        load time =    2193.15 ms
llama_perf_context_print: prompt eval time =    1496.69 ms /    25 tokens (   59.87 ms per token,    16.70 tokens per second)
llama_perf_context_print:        eval time =    2534.62 ms /    54 runs   (   46.94 ms per token,    21.30 tokens per second)
llama_perf_context_print:       total time =   13538.23 ms /    79 tokens
Interrupted by user
root@2f2c227f02d8:/ws# 

github-actions bot added the labels Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Apr 28, 2025
@yeahdongcn (Contributor, Author) commented on Apr 28, 2025

So far, I can only get it working with -fa. Without -fa, I encounter either garbled characters in the LLM replies or repeated GGGGGGG....

@JohannesGaessler @slaren Could you please share some tips on how to debug this kind of issue? I'd appreciate it!

Collaborator (review comment):
Why are you manually allocating and deallocating memory instead of using ggml_cuda_pool_alloc? Batched FP16 GEMM is used for attention without FlashAttention, so most likely this is where the bug is. I don't remember what the synchronization behavior of cudaFree is, but if it's done asynchronously from the kernel executions, that would explain why you get incorrect results.
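
For reference, a minimal sketch of the pool-backed allocation pattern the CUDA/MUSA backend already uses for temporary buffers, in place of raw device malloc/free. The function name and the element count `ne` are placeholders for illustration only, not code from this PR; `ggml_cuda_pool_alloc`, `ctx.pool()`, and `ctx.stream()` come from ggml-cuda/common.cuh.

```cpp
// Sketch only: illustrates the pool-allocation pattern, not this PR's actual code.
// Intended to be compiled inside ggml-cuda (which the MUSA backend reuses), so it
// includes the backend's own header rather than standalone CUDA headers.
#include "common.cuh"

// Hypothetical helper: `ne` is a placeholder element count for the FP16 scratch buffer.
static void batched_gemm_scratch_example(ggml_backend_cuda_context & ctx, int64_t ne) {
    cudaStream_t stream = ctx.stream();

    // Scratch buffer taken from the per-device pool. It is returned to the pool when
    // `src1_f16` goes out of scope, so there is no explicit cudaFree that could be
    // reordered relative to kernels still queued on `stream`.
    ggml_cuda_pool_alloc<half> src1_f16(ctx.pool(), ne);

    // ... enqueue the FP16 conversion kernel and the batched cuBLAS/muBLAS GEMM on
    // `stream`, reading and writing through src1_f16.get() ...
    GGML_UNUSED(stream);
}
```

Because the pool hands device memory back without freeing it, the buffer's lifetime is tied to the allocator object and the existing stream ordering, which avoids the cudaFree-vs-kernel race described above.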

@JohannesGaessler (Collaborator) commented:

> Could you please share some tips on how to debug this kind of issue?

Run test-backend-ops -o MUL_MAT.
