CUDA: more warps for mmvq on NVIDIA #5394

Merged

Conversation

JohannesGaessler
Collaborator

@JohannesGaessler commented Feb 7, 2024

Writing a better matrix-vector multiplication kernel has turned out to be more difficult than I anticipated, so I'm doing another optimization for the current implementation instead. On master only a single warp is currently used for each matrix row. On NVIDIA GPUs, however, using more warps seems to be beneficial in scenarios where the compute per memory bandwidth is relatively high (batch sizes > 1, k-quants, GPUs with low compute). So this PR sets the number of warps per row to 4 for NVIDIA GPUs. On AMD (RX 6800) this seemed to be detrimental, so there is no change there. The option LLAMA_CUDA_MMV_Y no longer affects performance; in my testing changing it was only beneficial on my RTX 3090 anyway, and the implementation in this PR performs better. On my systems the performance changes as follows:

| GPU | Model | Batch size | Test | t/s master | t/s cuda-faster-mmvq-8 | Speedup |
|---|---|---|---|---|---|---|
| RTX 3090 | llama 7B Q2_K_M | 1 | pp512 | 85.87 | 104.19 | 1.21 |
| RTX 3090 | llama 7B Q2_K_M | 2 | pp512 | 156.99 | 179.03 | 1.14 |
| RTX 3090 | llama 7B Q2_K_M | 4 | pp512 | 223.85 | 263.66 | 1.18 |
| RTX 3090 | llama 7B Q3_K_S | 1 | pp512 | 79.99 | 98.01 | 1.23 |
| RTX 3090 | llama 7B Q3_K_S | 2 | pp512 | 150.57 | 171.35 | 1.14 |
| RTX 3090 | llama 7B Q3_K_S | 4 | pp512 | 218.67 | 256.80 | 1.17 |
| RTX 3090 | llama 7B Q4_0 | 1 | pp512 | 128.69 | 129.95 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 2 | pp512 | 228.15 | 248.04 | 1.09 |
| RTX 3090 | llama 7B Q4_0 | 4 | pp512 | 332.54 | 368.32 | 1.11 |
| RTX 3090 | llama 7B Q4_1 | 1 | pp512 | 121.78 | 122.48 | 1.01 |
| RTX 3090 | llama 7B Q4_1 | 2 | pp512 | 220.15 | 237.83 | 1.08 |
| RTX 3090 | llama 7B Q4_1 | 4 | pp512 | 334.62 | 380.39 | 1.14 |
| RTX 3090 | llama 7B Q4_K_S | 1 | pp512 | 112.64 | 126.31 | 1.12 |
| RTX 3090 | llama 7B Q4_K_S | 2 | pp512 | 171.01 | 212.45 | 1.24 |
| RTX 3090 | llama 7B Q4_K_S | 4 | pp512 | 250.00 | 295.62 | 1.18 |
| RTX 3090 | llama 7B Q5_0 | 1 | pp512 | 107.66 | 114.10 | 1.06 |
| RTX 3090 | llama 7B Q5_0 | 2 | pp512 | 204.58 | 218.82 | 1.07 |
| RTX 3090 | llama 7B Q5_0 | 4 | pp512 | 315.89 | 339.53 | 1.07 |
| RTX 3090 | llama 7B Q5_1 | 1 | pp512 | 104.49 | 108.88 | 1.04 |
| RTX 3090 | llama 7B Q5_1 | 2 | pp512 | 200.39 | 212.70 | 1.06 |
| RTX 3090 | llama 7B Q5_1 | 4 | pp512 | 286.87 | 354.54 | 1.24 |
| RTX 3090 | llama 7B Q5_K_S | 1 | pp512 | 104.60 | 113.13 | 1.08 |
| RTX 3090 | llama 7B Q5_K_S | 2 | pp512 | 175.72 | 196.65 | 1.12 |
| RTX 3090 | llama 7B Q5_K_S | 4 | pp512 | 247.74 | 283.09 | 1.14 |
| RTX 3090 | llama 7B Q6_K | 1 | pp512 | 78.22 | 99.71 | 1.27 |
| RTX 3090 | llama 7B Q6_K | 2 | pp512 | 109.20 | 174.42 | 1.60 |
| RTX 3090 | llama 7B Q6_K | 4 | pp512 | 187.63 | 255.49 | 1.36 |
| RTX 3090 | llama 7B Q8_0 | 1 | pp512 | 85.14 | 87.02 | 1.02 |
| RTX 3090 | llama 7B Q8_0 | 2 | pp512 | 159.31 | 170.78 | 1.07 |
| RTX 3090 | llama 7B Q8_0 | 4 | pp512 | 252.16 | 303.63 | 1.20 |
| P40 | llama 7B Q2_K_M | 1 | pp512 | 30.90 | 46.35 | 1.50 |
| P40 | llama 7B Q2_K_M | 2 | pp512 | 40.97 | 47.86 | 1.17 |
| P40 | llama 7B Q2_K_M | 4 | pp512 | 64.91 | 69.35 | 1.07 |
| P40 | llama 7B Q3_K_S | 1 | pp512 | 28.42 | 44.61 | 1.57 |
| P40 | llama 7B Q3_K_S | 2 | pp512 | 40.14 | 46.95 | 1.17 |
| P40 | llama 7B Q3_K_S | 4 | pp512 | 64.61 | 68.20 | 1.06 |
| P40 | llama 7B Q4_0 | 1 | pp512 | 54.54 | 56.03 | 1.03 |
| P40 | llama 7B Q4_0 | 2 | pp512 | 57.65 | 57.47 | 1.00 |
| P40 | llama 7B Q4_0 | 4 | pp512 | 86.70 | 92.77 | 1.07 |
| P40 | llama 7B Q4_1 | 1 | pp512 | 54.14 | 53.99 | 1.00 |
| P40 | llama 7B Q4_1 | 2 | pp512 | 57.76 | 58.20 | 1.01 |
| P40 | llama 7B Q4_1 | 4 | pp512 | 91.18 | 93.80 | 1.03 |
| P40 | llama 7B Q4_K_S | 1 | pp512 | 50.88 | 50.76 | 1.00 |
| P40 | llama 7B Q4_K_S | 2 | pp512 | 43.40 | 52.68 | 1.21 |
| P40 | llama 7B Q4_K_S | 4 | pp512 | 72.58 | 79.37 | 1.09 |
| P40 | llama 7B Q5_0 | 1 | pp512 | 47.06 | 46.86 | 1.00 |
| P40 | llama 7B Q5_0 | 2 | pp512 | 51.63 | 52.54 | 1.02 |
| P40 | llama 7B Q5_0 | 4 | pp512 | 82.67 | 87.61 | 1.06 |
| P40 | llama 7B Q5_1 | 1 | pp512 | 47.51 | 47.57 | 1.00 |
| P40 | llama 7B Q5_1 | 2 | pp512 | 49.78 | 53.51 | 1.07 |
| P40 | llama 7B Q5_1 | 4 | pp512 | 84.39 | 90.21 | 1.07 |
| P40 | llama 7B Q5_K_S | 1 | pp512 | 38.83 | 43.30 | 1.12 |
| P40 | llama 7B Q5_K_S | 2 | pp512 | 48.80 | 48.65 | 1.00 |
| P40 | llama 7B Q5_K_S | 4 | pp512 | 67.02 | 76.11 | 1.14 |
| P40 | llama 7B Q6_K | 1 | pp512 | 35.05 | 35.61 | 1.02 |
| P40 | llama 7B Q6_K | 2 | pp512 | 41.45 | 42.64 | 1.03 |
| P40 | llama 7B Q6_K | 4 | pp512 | 58.63 | 68.99 | 1.18 |
| P40 | llama 7B Q8_0 | 1 | pp512 | 33.69 | 35.83 | 1.06 |
| P40 | llama 7B Q8_0 | 2 | pp512 | 42.41 | 44.28 | 1.04 |
| P40 | llama 7B Q8_0 | 4 | pp512 | 56.10 | 77.67 | 1.38 |

If you account for tail effects and the initial latency when launching a kernel, the current implementation achieves ~90% of the maximum theoretical matrix-vector multiplication performance for q8_0 on my RTX 3090.
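For readers who want to picture the approach: the actual mmvq kernel operates on quantized blocks through per-type vec_dot functions, but the multi-warp idea itself can be shown with a minimal, self-contained float sketch (an illustration only, not this PR's kernel; all names below are made up). Each block handles one matrix row, several warps accumulate partial dot products over the row, each warp reduces its partial sum with shuffles, and the per-warp results are combined through shared memory:

```cuda
// Toy illustration of "more warps per matrix row" for matrix-vector multiplication.
// Not the ggml-cuda mmvq kernel: plain floats instead of quantized blocks.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define WARP_SIZE 32

template <int nwarps>
__global__ void mul_mat_vec_nwarps(const float * __restrict__ x, const float * __restrict__ y,
                                   float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;                          // one block per matrix row
    const int tid = threadIdx.y*WARP_SIZE + threadIdx.x; // 0 .. nwarps*WARP_SIZE - 1

    // All warps of the block stride over the row together.
    float sum = 0.0f;
    for (int col = tid; col < ncols; col += nwarps*WARP_SIZE) {
        sum += x[row*ncols + col] * y[col];
    }

    // Reduce within each warp via shuffles.
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xffffffff, sum, offset, WARP_SIZE);
    }

    // Combine the per-warp partial sums through shared memory.
    __shared__ float partial[nwarps];
    if (threadIdx.x == 0) {
        partial[threadIdx.y] = sum;
    }
    __syncthreads();

    if (tid == 0) {
        float total = 0.0f;
        for (int i = 0; i < nwarps; ++i) {
            total += partial[i];
        }
        dst[row] = total;
    }
}

int main() {
    const int nrows = 4, ncols = 4096;
    std::vector<float> hx(nrows*ncols, 1.0f), hy(ncols, 2.0f), hdst(nrows, 0.0f);

    float *dx, *dy, *ddst;
    cudaMalloc(&dx,   hx.size()*sizeof(float));
    cudaMalloc(&dy,   hy.size()*sizeof(float));
    cudaMalloc(&ddst, hdst.size()*sizeof(float));
    cudaMemcpy(dx, hx.data(), hx.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy.data(), hy.size()*sizeof(float), cudaMemcpyHostToDevice);

    constexpr int nwarps = 4;                 // the value this PR uses on NVIDIA
    const dim3 block_dims(WARP_SIZE, nwarps, 1);
    mul_mat_vec_nwarps<nwarps><<<nrows, block_dims>>>(dx, dy, ddst, ncols);
    cudaMemcpy(hdst.data(), ddst, hdst.size()*sizeof(float), cudaMemcpyDeviceToHost);

    printf("dst[0] = %.1f (expected %.1f)\n", hdst[0], 2.0f*ncols);
    cudaFree(dx); cudaFree(dy); cudaFree(ddst);
    return 0;
}
```

The trade-off the benchmarks above are probing: more warps per row increases occupancy and hides memory latency when there is enough work per row, but adds reduction and launch overhead, which is why it helps on NVIDIA and hurts on the AMD cards tested.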

@sorasoras

@JohannesGaessler How could I test this on an RDNA 3 GPU such as the 7900 XTX to see if there is any improvement?

@JohannesGaessler
Collaborator Author

Setting MMVQ_NWARPS_AMD on line 5314 to 4 should work.
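For anyone else who wants to try this, a rough sketch of the kind of edit being described; MMVQ_NWARPS_AMD is the constant from this PR, while MMVQ_NWARPS_NVIDIA and the HIP guard are assumptions about the surrounding code, so treat this as illustrative rather than the exact ggml-cuda.cu source:

```cuda
// Illustrative only: per-vendor warp count for the mmvq kernel.
// Bumping MMVQ_NWARPS_AMD from 1 to 4 is the experiment suggested above.
#define MMVQ_NWARPS_NVIDIA 4  // this PR's default on NVIDIA
#define MMVQ_NWARPS_AMD    1  // change to 4 to test on an AMD GPU

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static constexpr int mmvq_nwarps = MMVQ_NWARPS_AMD;
#else
static constexpr int mmvq_nwarps = MMVQ_NWARPS_NVIDIA;
#endif
```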

@sorasoras

sorasoras commented Feb 7, 2024

> Setting MMVQ_NWARPS_AMD on line 5314 to 4 should work.

RDNA 3 behaves the same as RDNA 2: there is a performance regression of about 10%.

@8XXD8

8XXD8 commented Feb 8, 2024

With Radeon VII and Vega there is a notable increase in token generation speed when compiling with -DLLAMA_CUDA_MMV_Y=4.

master no mmv_y

| model | size | params | backend | ngl | sm | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | pp 512 | 133.54 ± 0.04 |
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | tg 128 | 26.47 ± 0.02 |

master mmv_y=4

| model | size | params | backend | ngl | sm | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | pp 512 | 133.53 ± 0.05 |
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | tg 128 | 31.32 ± 0.03 |

build: 8504d2d (2097)

Now with this PR LLAMA_CUDA_MMV_Y no longer has any effect, but MMVQ_NWARPS_AMD=4 brings the performance back.

PR warp=1

| model | size | params | backend | ngl | sm | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | pp 512 | 133.51 ± 0.04 |
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | tg 128 | 26.31 ± 0.01 |
| llama 70B Q4_K - Medium | 38.58 GiB | 68.98 B | ROCm | 99 | row | pp 512 | 42.79 ± 0.01 |
| llama 70B Q4_K - Medium | 38.58 GiB | 68.98 B | ROCm | 99 | row | tg 128 | 9.12 ± 0.01 |

PR warp=4

| model | size | params | backend | ngl | sm | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | pp 512 | 133.60 ± 0.07 |
| llama 13B Q6_K | 9.95 GiB | 13.02 B | ROCm | 99 | row | tg 128 | 32.40 ± 0.02 |
| llama 70B Q4_K - Medium | 38.58 GiB | 68.98 B | ROCm | 99 | row | pp 512 | 42.79 ± 0.01 |
| llama 70B Q4_K - Medium | 38.58 GiB | 68.98 B | ROCm | 99 | row | tg 128 | 12.52 ± 0.02 |

build: 7a0f63a (2094)

@JohannesGaessler
Collaborator Author

I pushed a patch that enables the higher warp count for RDNA 1 and older. I don't actually have access to an RDNA 1 card to test this on, though. But if this causes a performance regression I expect someone to quickly open an issue, at which point the threshold can just be adjusted.

@8XXD8

8XXD8 commented Feb 8, 2024

Thank you, performance is good again

| model | size | params | backend | ngl | sm | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 70B Q4_K - Medium | 38.58 GiB | 68.98 B | ROCm | 99 | row | pp 512 | 42.79 ± 0.01 |
| llama 70B Q4_K - Medium | 38.58 GiB | 68.98 B | ROCm | 99 | row | tg 128 | 12.53 ± 0.00 |

build: f195490 (2095)

@slaren
Member

slaren commented Feb 8, 2024

Do you think it could be worth extending this for batch sizes 5-8 now? I noticed that it becomes slower after 4.

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 1 | 130.78 ± 10.68 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 2 | 249.47 ± 3.42 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 3 | 317.04 ± 5.28 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 4 | 365.03 ± 5.21 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 5 | 245.98 ± 1.70 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 6 | 293.72 ± 4.27 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 7 | 340.79 ± 4.57 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 8 | 388.66 ± 1.74 |

build: 00ebc31 (2094)

@JohannesGaessler
Collaborator Author

I didn't test it, but before I do that I would like to experiment a bit more with a better way of doing the calculation.

@JohannesGaessler merged commit 8e6a9d2 into ggml-org:master on Feb 8, 2024
@Ph0rk0z

Ph0rk0z commented Feb 10, 2024

After this (I think) I am up to 17.4 t/s on 70B. Feels like time to first token got longer and I went from 299 t/s to 229 t/s PP. Still not back to peak performance, but it's getting better.

I can confirm that setting MMV_Y no longer seems to have any impact.

This is dual 3090s with NVLink. I have taken the P40s down for now.

@JohannesGaessler
Collaborator Author

> Feels like time to first token got longer and I went from 299 t/s to 229 t/s PP.

This PR had no effect whatsoever on batch sizes > 4, so I don't think that is related.

@Ph0rk0z

Ph0rk0z commented Feb 10, 2024

I only use a batch size of 1.

@JohannesGaessler
Collaborator Author

The time to first token is ~90% determined by matrix multiplications with batch sizes >> 1, done via either cuBLAS GEMM or mul_mat_q. Neither of those operations was touched by this PR. It doesn't matter what batch size you use for generation.
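To make that split concrete, here is a purely illustrative sketch of the dispatch being described (names and the cutoff are placeholders inferred from this thread, not the actual ggml-cuda dispatch code): small batches over quantized weights take the mat-vec path this PR touches, while prompt processing with large batches goes through cuBLAS GEMM or mul_mat_q.

```cuda
// Illustrative dispatch sketch only; names are placeholders.
// The cutoff of 4 reflects "no effect whatsoever on batch sizes > 4" above.
enum class mul_mat_path { MMVQ, GEMM_OR_MMQ };

static mul_mat_path choose_mul_mat_path(const int batch_size, const bool weights_quantized) {
    const int mmvq_max_batch_size = 4; // assumption based on this discussion

    if (weights_quantized && batch_size <= mmvq_max_batch_size) {
        return mul_mat_path::MMVQ;     // token generation: affected by this PR's warp change
    }
    return mul_mat_path::GEMM_OR_MMQ;  // prompt processing: cuBLAS GEMM / mul_mat_q, untouched here
}
```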

@Ph0rk0z

Ph0rk0z commented Feb 10, 2024

You are right. I did more A/B testing and it has to be related to something else. My first generation seems to take a while.

Layer splitting is now faster. Tensor cores still lag, not necessarily from this PR but with the current codebase. I need to test with 3090/3090/2080 (22 GB) because the regression there was previously much larger than with just 2 GPUs. This is definitely an improvement vs. last week, though.

@JohannesGaessler
Collaborator Author

> Tensor cores still lag, not necessarily from this PR but with the current codebase.

The biggest issue is that there is no support for the optimizations I added in #3110.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024