Conversation

jeffbolznv (Collaborator)
I think glslang will translate an access like x[i][1].z to

```
OpAccessChain ... x, i, 1, 2
OpLoad float16_t ...
```

rather than loading all of x[i] in a single OpLoad. Change the code to explicitly load the vector/matrix into a local variable first.
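
For illustration, here is a minimal GLSL sketch of the two access patterns (hypothetical buffer and variable names, not the actual mul_mm.comp code):

```glsl
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

// Hypothetical bindings for illustration only.
layout(std430, binding = 0) readonly buffer A { f16mat2x4 x[]; };
layout(std430, binding = 1) writeonly buffer D { float d[]; };

layout(local_size_x = 64) in;

void main() {
    const uint i = gl_GlobalInvocationID.x;

    // Per-component pattern: each swizzle compiles to its own
    //   OpAccessChain ... x, i, 1, <component>
    //   OpLoad float16_t ...
    // i.e. a separate trip through the buffer for every component:
    // float s = float(x[i][1].x) + float(x[i][1].y) + float(x[i][1].z);

    // Explicit-load pattern: one OpLoad pulls the whole column into a
    // local, and the swizzles then read the register copy instead:
    f16vec4 col = x[i][1];
    float s = float(col.x) + float(col.y) + float(col.z);

    d[i] = s;
}
```

Hoisting the column into a local keeps the repeated component reads in registers rather than re-deriving a buffer address for each one.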

Perf on 4070 using coopmat1:

4070 before:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 10 --prio 1 -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q6_K.gguf -m c:\models\DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf -m c:\models\Llama-3.2-1B.Q2_K.gguf -m c:\models\Llama-3.2-1B.Q3_K_S.gguf -m c:\models\llama-3.2-3b-instruct-q5_k_m.gguf -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\Qwen2.5-7B-Instruct-1M-Q2_K.gguf  -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf -m c:\models\Phi-3-mini-4k-instruct-q4.gguf -m c:\models\llama-2-7b.Q4_0.gguf -m c:\models\llama-3.2-3b-instruct-q8_0.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           pp512 |       2258.49 ± 5.01 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           pp512 |       2288.98 ± 6.67 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           pp512 |       1131.04 ± 1.90 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           pp512 |    13741.32 ± 243.38 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           pp512 |    12725.26 ± 138.57 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           pp512 |      4196.10 ± 11.26 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      2224.48 ± 17.12 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           pp512 |      2633.34 ± 13.63 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |       2481.31 ± 9.15 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      2179.18 ± 12.21 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           pp512 |      3831.32 ± 72.35 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           pp512 |      3337.53 ± 10.32 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           pp512 |     5849.98 ± 170.81 |

4070 after:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 10 --prio 1 -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q6_K.gguf -m c:\models\DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf -m c:\models\Llama-3.2-1B.Q2_K.gguf -m c:\models\Llama-3.2-1B.Q3_K_S.gguf -m c:\models\llama-3.2-3b-instruct-q5_k_m.gguf -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\Qwen2.5-7B-Instruct-1M-Q2_K.gguf  -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf -m c:\models\Phi-3-mini-4k-instruct-q4.gguf -m c:\models\llama-2-7b.Q4_0.gguf -m c:\models\llama-3.2-3b-instruct-q8_0.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           pp512 |       2475.83 ± 4.68 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           pp512 |       2440.91 ± 7.36 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           pp512 |       1218.58 ± 1.25 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           pp512 |    15013.23 ± 281.91 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           pp512 |    13817.10 ± 209.06 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           pp512 |      4601.85 ± 14.15 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      2337.09 ± 20.95 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           pp512 |      2889.11 ± 11.30 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |       2578.32 ± 8.07 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      2185.16 ± 11.06 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           pp512 |      4162.22 ± 81.04 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           pp512 |      3832.55 ± 13.41 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           pp512 |     6384.55 ± 517.12 |

0cc4m (Collaborator) commented Sep 2, 2025

That was one of the things I did in #12515, but never finished. I should get back to cleaning up the mul_mm.comp file.

0cc4m (Collaborator) commented Sep 7, 2025

| gpu_info                              | backends | model_type             | model_size | flash_attn |  test | avg_ts(master) | avg_ts(pr) |     % |
| ------------------------------------- | -------- | ---------------------- | ---------: | ---------: | ----: | -------------: | ---------: | ----: |
| AMD Radeon (TM) Pro VII (RADV VEGA20) | Vulkan   | llama 13B Q4_K - Small |  12.61 GiB |          0 | pp512 |          93.77 |      93.96 | +0.2% |
| AMD Radeon (TM) Pro VII (RADV VEGA20) | Vulkan   | llama 13B Q4_K - Small |  12.61 GiB |          0 | tg128 |          25.94 |      25.92 | -0.1% |
| AMD Radeon (TM) Pro VII (RADV VEGA20) | Vulkan   | llama 8B Q4_K - Small  |   4.36 GiB |          0 | pp512 |         292.78 |     291.76 | -0.3% |
| AMD Radeon (TM) Pro VII (RADV VEGA20) | Vulkan   | llama 8B Q4_K - Small  |   4.36 GiB |          0 | tg128 |          71.63 |      71.30 | -0.5% |
| AMD Radeon (TM) Pro VII (RADV VEGA20) | Vulkan   | llama 8B Q8_0          |   7.95 GiB |          0 | pp512 |         650.77 |     651.45 | +0.1% |
| AMD Radeon (TM) Pro VII (RADV VEGA20) | Vulkan   | llama 8B Q8_0          |   7.95 GiB |          0 | tg128 |          70.74 |      70.27 | -0.7% |
| Intel(R) Arc(tm) A770 Graphics (DG2)  | Vulkan   | llama 13B Q4_K - Small |  12.61 GiB |          0 | pp512 |          32.34 |      32.34 | +0.0% |
| Intel(R) Arc(tm) A770 Graphics (DG2)  | Vulkan   | llama 13B Q4_K - Small |  12.61 GiB |          0 | tg128 |          17.00 |      16.98 | -0.1% |
| Intel(R) Arc(tm) A770 Graphics (DG2)  | Vulkan   | llama 8B Q4_K - Small  |   4.36 GiB |          0 | pp512 |         100.39 |     100.45 | +0.1% |
| Intel(R) Arc(tm) A770 Graphics (DG2)  | Vulkan   | llama 8B Q4_K - Small  |   4.36 GiB |          0 | tg128 |          37.82 |      37.75 | -0.2% |
| Intel(R) Arc(tm) A770 Graphics (DG2)  | Vulkan   | llama 8B Q8_0          |   7.95 GiB |          0 | pp512 |         745.27 |     741.57 | -0.5% |
| Intel(R) Arc(tm) A770 Graphics (DG2)  | Vulkan   | llama 8B Q8_0          |   7.95 GiB |          0 | tg128 |          34.89 |      34.70 | -0.6% |
| NVIDIA GeForce RTX 3090               | Vulkan   | llama 13B Q4_K - Small |  12.61 GiB |          0 | pp512 |         847.88 |     883.29 | +4.2% |
| NVIDIA GeForce RTX 3090               | Vulkan   | llama 13B Q4_K - Small |  12.61 GiB |          0 | tg128 |          44.05 |      43.91 | -0.3% |
| NVIDIA GeForce RTX 3090               | Vulkan   | llama 8B Q4_K - Small  |   4.36 GiB |          0 | pp512 |        2403.73 |    2556.77 | +6.4% |
| NVIDIA GeForce RTX 3090               | Vulkan   | llama 8B Q4_K - Small  |   4.36 GiB |          0 | tg128 |         121.55 |     121.07 | -0.4% |
| NVIDIA GeForce RTX 3090               | Vulkan   | llama 8B Q8_0          |   7.95 GiB |          0 | pp512 |        3027.63 |    3191.02 | +5.4% |
| NVIDIA GeForce RTX 3090               | Vulkan   | llama 8B Q8_0          |   7.95 GiB |          0 | tg128 |          90.85 |      90.81 | -0.0% |

This only seems to help on Nvidia; the AMD and Intel results are unchanged within noise.

0cc4m merged commit 267e998 into ggml-org:master on Sep 7, 2025, with 48 checks passed.