CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 #15131

Merged: JohannesGaessler merged 2 commits into ggml-org:master on Aug 7, 2025

Conversation

JohannesGaessler (Collaborator)

This PR adds a GEMM kernel for FP32/FP16/BF16 and batch sizes <= 16. The kernel is written in such a way that it does not need any synchronization between warps except for a constant number of times at the end. It can also directly handle mixed input/output types without the need for auxiliary kernels. The kernel uses tensor cores, so FP16 needs Turing or newer while BF16 and FP32 need Ampere or newer. FP32 data is downcast to TF32, but I think that for neural networks this does not make a meaningful difference (and there would only ever be a difference for a model trained at FP32 precision in the first place). The end-to-end speedup for floating-point models is something like 5-20%; quantized models without FlashAttention also gain a few percent.
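
As a rough illustration of the dispatch rule described above, a minimal sketch (hypothetical helper name `use_mul_mat_f`; the real selection logic in ggml-cuda also checks hardware support and other constraints):

```cpp
#include "ggml.h"

// Sketch only: the new kernel targets mat-muls whose weight type is a float
// type (F32/F16/BF16) and whose batch dimension ne11 (columns of src1) is <= 16.
static bool use_mul_mat_f(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
    const bool float_type =
        src0->type == GGML_TYPE_F32 ||
        src0->type == GGML_TYPE_F16 ||
        src0->type == GGML_TYPE_BF16;
    return float_type && src1->ne[1] <= 16; // ne11 = batch size (number of tokens)
}
```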

I renamed NEW_MMA_AVAILABLE to TURING_MMA_AVAILABLE and added a new define AMPERE_MMA_AVAILABLE.
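
A sketch of what that gating amounts to, assuming the usual compute-capability thresholds (Turing = CC 7.5, Ampere = CC 8.0); the exact definitions in the ggml CUDA headers may differ:

```cpp
// Sketch: which tensor-core MMA paths are compiled in, per architecture.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#define TURING_MMA_AVAILABLE   // FP16 tensor core MMA (formerly NEW_MMA_AVAILABLE)
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define AMPERE_MMA_AVAILABLE   // adds BF16 and TF32 tensor core MMA
#endif
```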

Support for MUL_MAT_ID is not yet implemented.

| GPU | Model | FlashAttention | Microbatch size | Test | t/s master | t/s cuda-mmf-2 | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | llama 1B all F32 | Yes | 1 | pp512 | 201.55 | 201.45 | 1.00 |
| RTX 3090 | llama 1B all F32 | Yes | 2 | pp512 | 377.48 | 377.68 | 1.00 |
| RTX 3090 | llama 1B all F32 | Yes | 3 | pp512 | 532.17 | 532.20 | 1.00 |
| RTX 3090 | llama 1B all F32 | Yes | 4 | pp512 | 639.58 | 711.83 | 1.11 |
| RTX 3090 | llama 1B all F32 | Yes | 5 | pp512 | 735.52 | 875.70 | 1.19 |
| RTX 3090 | llama 1B all F32 | Yes | 6 | pp512 | 876.94 | 1046.83 | 1.19 |
| RTX 3090 | llama 1B all F32 | Yes | 7 | pp512 | 1013.73 | 1206.49 | 1.19 |
| RTX 3090 | llama 1B all F32 | Yes | 8 | pp512 | 1160.20 | 1392.15 | 1.20 |
| RTX 3090 | llama 1B all F32 | Yes | 9 | pp512 | 1279.06 | 1489.12 | 1.16 |
| RTX 3090 | llama 1B all F32 | Yes | 10 | pp512 | 1394.79 | 1619.51 | 1.16 |
| RTX 3090 | llama 1B all F32 | Yes | 11 | pp512 | 1527.16 | 1780.27 | 1.17 |
| RTX 3090 | llama 1B all F32 | Yes | 12 | pp512 | 1647.89 | 1932.23 | 1.17 |
| RTX 3090 | llama 1B all F32 | Yes | 13 | pp512 | 1773.88 | 2060.67 | 1.16 |
| RTX 3090 | llama 1B all F32 | Yes | 14 | pp512 | 1893.11 | 2216.77 | 1.17 |
| RTX 3090 | llama 1B all F32 | Yes | 15 | pp512 | 2021.20 | 2328.22 | 1.15 |
| RTX 3090 | llama 1B all F32 | Yes | 16 | pp512 | 2164.53 | 2511.36 | 1.16 |
| RTX 3090 | llama 8B BF16 | Yes | 1 | pp512 | 58.13 | 58.14 | 1.00 |
| RTX 3090 | llama 8B BF16 | Yes | 2 | pp512 | 108.93 | 111.21 | 1.02 |
| RTX 3090 | llama 8B BF16 | Yes | 3 | pp512 | 142.63 | 164.94 | 1.16 |
| RTX 3090 | llama 8B BF16 | Yes | 4 | pp512 | 193.33 | 218.84 | 1.13 |
| RTX 3090 | llama 8B BF16 | Yes | 5 | pp512 | 237.77 | 268.22 | 1.13 |
| RTX 3090 | llama 8B BF16 | Yes | 6 | pp512 | 282.56 | 319.46 | 1.13 |
| RTX 3090 | llama 8B BF16 | Yes | 7 | pp512 | 327.77 | 368.87 | 1.13 |
| RTX 3090 | llama 8B BF16 | Yes | 8 | pp512 | 377.23 | 423.93 | 1.12 |
| RTX 3090 | llama 8B BF16 | Yes | 9 | pp512 | 414.08 | 451.52 | 1.09 |
| RTX 3090 | llama 8B BF16 | Yes | 10 | pp512 | 452.71 | 492.62 | 1.09 |
| RTX 3090 | llama 8B BF16 | Yes | 11 | pp512 | 498.10 | 541.52 | 1.09 |
| RTX 3090 | llama 8B BF16 | Yes | 12 | pp512 | 543.26 | 587.35 | 1.08 |
| RTX 3090 | llama 8B BF16 | Yes | 13 | pp512 | 580.98 | 627.34 | 1.08 |
| RTX 3090 | llama 8B BF16 | Yes | 14 | pp512 | 625.99 | 673.90 | 1.08 |
| RTX 3090 | llama 8B BF16 | Yes | 15 | pp512 | 659.83 | 710.28 | 1.08 |
| RTX 3090 | llama 8B BF16 | Yes | 16 | pp512 | 716.32 | 766.07 | 1.07 |
| RTX 3090 | llama 8B F16 | Yes | 1 | pp512 | 58.25 | 58.27 | 1.00 |
| RTX 3090 | llama 8B F16 | Yes | 2 | pp512 | 109.65 | 111.04 | 1.01 |
| RTX 3090 | llama 8B F16 | Yes | 3 | pp512 | 141.37 | 164.82 | 1.17 |
| RTX 3090 | llama 8B F16 | Yes | 4 | pp512 | 198.21 | 218.88 | 1.10 |
| RTX 3090 | llama 8B F16 | Yes | 5 | pp512 | 242.82 | 268.60 | 1.11 |
| RTX 3090 | llama 8B F16 | Yes | 6 | pp512 | 289.38 | 319.61 | 1.10 |
| RTX 3090 | llama 8B F16 | Yes | 7 | pp512 | 340.19 | 368.54 | 1.08 |
| RTX 3090 | llama 8B F16 | Yes | 8 | pp512 | 390.73 | 423.63 | 1.08 |
| RTX 3090 | llama 8B F16 | Yes | 9 | pp512 | 427.20 | 450.76 | 1.06 |
| RTX 3090 | llama 8B F16 | Yes | 10 | pp512 | 468.15 | 492.08 | 1.05 |
| RTX 3090 | llama 8B F16 | Yes | 11 | pp512 | 512.35 | 539.89 | 1.05 |
| RTX 3090 | llama 8B F16 | Yes | 12 | pp512 | 559.50 | 585.90 | 1.05 |
| RTX 3090 | llama 8B F16 | Yes | 13 | pp512 | 598.67 | 625.12 | 1.04 |
| RTX 3090 | llama 8B F16 | Yes | 14 | pp512 | 644.39 | 671.95 | 1.04 |
| RTX 3090 | llama 8B F16 | Yes | 15 | pp512 | 680.68 | 706.71 | 1.04 |
| RTX 3090 | llama 8B F16 | Yes | 16 | pp512 | 739.00 | 764.17 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 1 | pp512 | 148.11 | 148.35 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | No | 2 | pp512 | 259.93 | 274.18 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | No | 3 | pp512 | 357.29 | 388.12 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | No | 4 | pp512 | 446.07 | 461.47 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 5 | pp512 | 494.74 | 511.09 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 6 | pp512 | 509.57 | 523.14 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 7 | pp512 | 537.10 | 549.62 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | 8 | pp512 | 553.09 | 563.81 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | 9 | pp512 | 698.41 | 716.33 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 10 | pp512 | 771.16 | 790.32 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | 11 | pp512 | 839.78 | 864.62 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 12 | pp512 | 913.80 | 931.62 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | 13 | pp512 | 974.63 | 995.87 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | 14 | pp512 | 1042.64 | 1062.90 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | 15 | pp512 | 1108.68 | 1140.60 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | No | 16 | pp512 | 1204.02 | 1228.97 | 1.02 |
| RTX 4090 | llama 1B all F32 | Yes | 1 | pp512 | 227.50 | 227.57 | 1.00 |
| RTX 4090 | llama 1B all F32 | Yes | 2 | pp512 | 428.70 | 428.18 | 1.00 |
| RTX 4090 | llama 1B all F32 | Yes | 3 | pp512 | 631.20 | 630.64 | 1.00 |
| RTX 4090 | llama 1B all F32 | Yes | 4 | pp512 | 832.14 | 827.52 | 0.99 |
| RTX 4090 | llama 1B all F32 | Yes | 5 | pp512 | 1018.63 | 1015.63 | 1.00 |
| RTX 4090 | llama 1B all F32 | Yes | 6 | pp512 | 1189.68 | 1214.84 | 1.02 |
| RTX 4090 | llama 1B all F32 | Yes | 7 | pp512 | 1362.52 | 1404.66 | 1.03 |
| RTX 4090 | llama 1B all F32 | Yes | 8 | pp512 | 1540.00 | 1621.00 | 1.05 |
| RTX 4090 | llama 1B all F32 | Yes | 9 | pp512 | 1595.33 | 1738.44 | 1.09 |
| RTX 4090 | llama 1B all F32 | Yes | 10 | pp512 | 1759.14 | 1900.92 | 1.08 |
| RTX 4090 | llama 1B all F32 | Yes | 11 | pp512 | 1940.24 | 2103.03 | 1.08 |
| RTX 4090 | llama 1B all F32 | Yes | 12 | pp512 | 2165.25 | 2292.59 | 1.06 |
| RTX 4090 | llama 1B all F32 | Yes | 13 | pp512 | 2324.30 | 2453.69 | 1.06 |
| RTX 4090 | llama 1B all F32 | Yes | 14 | pp512 | 2501.92 | 2641.74 | 1.06 |
| RTX 4090 | llama 1B all F32 | Yes | 15 | pp512 | 2538.54 | 2797.14 | 1.10 |
| RTX 4090 | llama 1B all F32 | Yes | 16 | pp512 | 2755.12 | 3039.66 | 1.10 |
| RTX 4090 | llama 8B BF16 | Yes | 1 | pp512 | 65.21 | 65.20 | 1.00 |
| RTX 4090 | llama 8B BF16 | Yes | 2 | pp512 | 124.64 | 124.32 | 1.00 |
| RTX 4090 | llama 8B BF16 | Yes | 3 | pp512 | 184.24 | 185.23 | 1.01 |
| RTX 4090 | llama 8B BF16 | Yes | 4 | pp512 | 245.00 | 246.37 | 1.01 |
| RTX 4090 | llama 8B BF16 | Yes | 5 | pp512 | 286.69 | 304.10 | 1.06 |
| RTX 4090 | llama 8B BF16 | Yes | 6 | pp512 | 342.47 | 363.19 | 1.06 |
| RTX 4090 | llama 8B BF16 | Yes | 7 | pp512 | 396.75 | 420.78 | 1.06 |
| RTX 4090 | llama 8B BF16 | Yes | 8 | pp512 | 457.75 | 485.16 | 1.06 |
| RTX 4090 | llama 8B BF16 | Yes | 9 | pp512 | 504.39 | 529.01 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 10 | pp512 | 552.86 | 579.13 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 11 | pp512 | 609.47 | 638.90 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 12 | pp512 | 664.67 | 695.90 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 13 | pp512 | 712.77 | 745.72 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 14 | pp512 | 768.53 | 803.49 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 15 | pp512 | 812.78 | 849.53 | 1.05 |
| RTX 4090 | llama 8B BF16 | Yes | 16 | pp512 | 883.85 | 925.89 | 1.05 |
| RTX 4090 | llama 8B F16 | Yes | 1 | pp512 | 65.22 | 65.22 | 1.00 |
| RTX 4090 | llama 8B F16 | Yes | 2 | pp512 | 124.70 | 124.33 | 1.00 |
| RTX 4090 | llama 8B F16 | Yes | 3 | pp512 | 184.00 | 185.19 | 1.01 |
| RTX 4090 | llama 8B F16 | Yes | 4 | pp512 | 245.26 | 246.31 | 1.00 |
| RTX 4090 | llama 8B F16 | Yes | 5 | pp512 | 284.84 | 304.20 | 1.07 |
| RTX 4090 | llama 8B F16 | Yes | 6 | pp512 | 339.99 | 363.29 | 1.07 |
| RTX 4090 | llama 8B F16 | Yes | 7 | pp512 | 393.74 | 420.65 | 1.07 |
| RTX 4090 | llama 8B F16 | Yes | 8 | pp512 | 454.02 | 485.09 | 1.07 |
| RTX 4090 | llama 8B F16 | Yes | 9 | pp512 | 499.39 | 529.16 | 1.06 |
| RTX 4090 | llama 8B F16 | Yes | 10 | pp512 | 546.86 | 579.04 | 1.06 |
| RTX 4090 | llama 8B F16 | Yes | 11 | pp512 | 602.03 | 638.47 | 1.06 |
| RTX 4090 | llama 8B F16 | Yes | 12 | pp512 | 662.61 | 695.72 | 1.05 |
| RTX 4090 | llama 8B F16 | Yes | 13 | pp512 | 711.06 | 745.89 | 1.05 |
| RTX 4090 | llama 8B F16 | Yes | 14 | pp512 | 766.37 | 803.39 | 1.05 |
| RTX 4090 | llama 8B F16 | Yes | 15 | pp512 | 811.07 | 849.13 | 1.05 |
| RTX 4090 | llama 8B F16 | Yes | 16 | pp512 | 881.10 | 925.64 | 1.05 |
| RTX 4090 | llama 8B Q4_0 | No | 1 | pp512 | 179.53 | 179.39 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | No | 2 | pp512 | 327.26 | 334.71 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | No | 3 | pp512 | 479.97 | 497.19 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | No | 4 | pp512 | 632.18 | 659.96 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | No | 5 | pp512 | 746.06 | 772.88 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | No | 6 | pp512 | 872.40 | 905.02 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | No | 7 | pp512 | 966.87 | 997.43 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 8 | pp512 | 1055.74 | 1090.62 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 9 | pp512 | 1076.77 | 1107.31 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 10 | pp512 | 1183.07 | 1216.63 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 11 | pp512 | 1305.52 | 1340.94 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 12 | pp512 | 1424.18 | 1460.17 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 13 | pp512 | 1525.92 | 1564.21 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 14 | pp512 | 1637.70 | 1684.51 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 15 | pp512 | 1737.93 | 1785.59 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | No | 16 | pp512 | 1891.88 | 1930.50 | 1.02 |

github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Aug 6, 2025
slaren (Member) commented on Aug 6, 2025

test-backend-ops with 3090 Ti:

| Backend | GGML op | Op parameters | TFLOPS master | TFLOPS cuda-mmf-3 | Speedup |
| --- | --- | --- | --- | --- | --- |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 0.96 | 0.97 | 1.00 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 1.88 | 1.91 | 1.02 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 2.32 | 2.87 | 1.24 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 3.54 | 3.81 | 1.07 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 4.41 | 4.74 | 1.07 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=6,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 5.29 | 5.67 | 1.07 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=7,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 6.17 | 6.60 | 1.07 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 7.02 | 7.54 | 1.07 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=9,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 7.89 | 8.33 | 1.06 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=10,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 8.74 | 9.21 | 1.05 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=11,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 9.61 | 10.11 | 1.05 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=12,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 10.45 | 10.92 | 1.04 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=13,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 11.28 | 11.83 | 1.05 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=14,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 12.14 | 12.67 | 1.04 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=15,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 13.00 | 13.60 | 1.05 |
| CUDA0 | MUL_MAT | type_a=bf16,type_b=f32,m=4096,n=16,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 13.80 | 14.48 | 1.05 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 0.96 | 0.96 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 1.89 | 1.91 | 1.01 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 2.24 | 2.87 | 1.28 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 3.48 | 3.82 | 1.10 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 4.34 | 4.74 | 1.09 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=6,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 5.20 | 5.68 | 1.09 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=7,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 6.16 | 6.60 | 1.07 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 7.03 | 7.53 | 1.07 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=9,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 7.87 | 8.33 | 1.06 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=10,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 8.73 | 9.23 | 1.06 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=11,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 9.54 | 10.12 | 1.06 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=12,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 10.39 | 10.90 | 1.05 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=13,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 11.26 | 11.84 | 1.05 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=14,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 12.07 | 12.71 | 1.05 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=15,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 12.90 | 13.59 | 1.05 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=16,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 13.74 | 14.47 | 1.05 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 0.49 | 0.49 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 0.97 | 0.97 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 1.37 | 1.37 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 1.79 | 1.93 | 1.08 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 2.32 | 2.41 | 1.04 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=6,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 2.78 | 2.88 | 1.04 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=7,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 3.24 | 3.36 | 1.04 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 3.70 | 3.83 | 1.04 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=9,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 4.16 | 4.26 | 1.03 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=10,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 4.62 | 4.74 | 1.02 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=11,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 5.07 | 5.20 | 1.03 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=12,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 5.53 | 5.68 | 1.03 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=13,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 5.99 | 6.14 | 1.02 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=14,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 6.44 | 6.58 | 1.02 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=15,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 6.90 | 7.05 | 1.02 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=16,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0 | 7.35 | 7.52 | 1.02 |

JohannesGaessler force-pushed the cuda-mmf-3 branch 5 times, most recently from 6eeaf23 to 781cbe0 on August 6, 2025 at 20:40
IMbackK (Collaborator) commented on Aug 6, 2025

> FP32 data is downcast to TF32, but I think that for neural networks this does not make a meaningful difference (and there would only ever be a difference for a model trained at FP32 precision in the first place).

I don't really like this part, for several reasons:

- For one thing, small networks for random industrial monitoring/computer vision tasks are often trained at FP32. Inference at FP32 is then done as a baseline to check against the target inference datatype. Downcasting to TF32 without informing the user of ggml's API about it unknowingly falsifies the results the user gets for this task.
- I think if someone actively chooses FP32 for inference of a neural network, they probably have a specific reason in mind and want very high numeric accuracy.

I find ggml pretty useful as a general tensor processing library, since nothing else in C++ combines being as self-contained dependency-wise with such broad hardware support. Needless to say, for the generic tensor-processing use case simply downcasting is not acceptable.

I also just generally find an API that claims to work in FP32 but in reality gives you FP16 with a couple of extra exponent bits pretty tacky.

I guess in the end it doesn't really matter for me personally, since once this is ported to AMD GPUs we can have full precision, as the matrix accelerators there have full FP32 (and even FP64) support.

JohannesGaessler (Collaborator, Author)

I understand your point regarding precision and can't say that I disagree with it, but I think that "precision" as a function of ggml_prec is poorly defined in ggml in the first place. For a matrix multiplication in particular, I think by far the most important part is which precision is used for the accumulation, and there FP32 is always used.
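
To illustrate the accumulation point, here is a standalone CUDA sketch using the public WMMA API (the PR's kernel uses raw mma PTX instead): the FP32 inputs are rounded to TF32, but every product is accumulated into FP32 registers.

```cpp
#include <mma.h>
using namespace nvcuda;

// Requires compilation for sm_80 (Ampere) or newer; launch with one warp, e.g. <<<1, 32>>>.
__global__ void tf32_mma_demo(const float * A, const float * B, float * C) {
    // TF32 WMMA tile shape on Ampere is 16x16x8 (A is 16x8, B is 8x16, C is 16x16).
    wmma::fragment<wmma::matrix_a, 16, 16, 8, wmma::precision::tf32, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 8, wmma::precision::tf32, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 8, float> c;

    wmma::fill_fragment(c, 0.0f);
    wmma::load_matrix_sync(a, A, 8); // leading dimension = k = 8
    wmma::load_matrix_sync(b, B, 8);

    // Inputs are rounded from FP32 to TF32 ...
    for (int i = 0; i < a.num_elements; ++i) {
        a.x[i] = wmma::__float_to_tf32(a.x[i]);
    }
    for (int i = 0; i < b.num_elements; ++i) {
        b.x[i] = wmma::__float_to_tf32(b.x[i]);
    }

    // ... but the products are accumulated in FP32.
    wmma::mma_sync(c, a, b, c);
    wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
}
```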

Since you're already here, do you happen to know the correct way to get a 2x BF16 struct like nv_bfloat162 in HIP? I can't seem to figure it out and resorted to using short2 as a hack so the code would compile.

IMbackK (Collaborator) commented on Aug 6, 2025

`__hip_bfloat162`, part of `hip_bf16.h`
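
One possible way to plug that into code written against the CUDA type names, as a sketch (assuming the `GGML_USE_HIP` guard; the eventual shim in ggml may look different):

```cpp
// Sketch: map the CUDA packed 2x BF16 type onto HIP's equivalent.
#if defined(GGML_USE_HIP)
#include <hip/hip_bf16.h>              // provides __hip_bfloat16 / __hip_bfloat162
typedef __hip_bfloat162 nv_bfloat162;  // use HIP's packed type under the CUDA name
#else
#include <cuda_bf16.h>                 // provides nv_bfloat162 natively
#endif
```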

IMbackK (Collaborator) commented on Aug 6, 2025

With regard to TF32 matmuls, the equivalent PyTorch issue has some further useful discussion on the topic: pytorch/pytorch#67384

IMbackK (Collaborator) commented on Aug 6, 2025

Side note (for myself, really): we should not be using `hip_bfloat16` at all anymore, it's deprecated.

JohannesGaessler (Collaborator, Author)

I tried including bf16.h but then I get an error during linking, presumably since I'm also including bfloat16.h. When I include only bf16.h and try to use __hip_bfloat16 I can't get the preexisting code to compile.

My opinion regarding precision is that ggml should just go with the fast option if it can be reasonably assumed that the precision will be good enough for most cases, and only strictly adhere to e.g. IEEE 754 if explicitly requested by the user.

IMbackK (Collaborator) commented on Aug 6, 2025

> I tried including bf16.h but then I get an error during linking, presumably since I'm also including bfloat16.h. When I include only bf16.h and try to use __hip_bfloat16 I can't get the preexisting code to compile.

Yeah, you can just use `short2` and I will deal with this later.

> My opinion regarding precision is that ggml should just go with the fast option if it can be reasonably assumed that the precision will be good enough for most cases, and only strictly adhere to e.g. IEEE 754 if explicitly requested by the user.

I think that's OK for the most part, but by selecting FP32 the user has already told you that they care mostly about precision, so IMO in this case it is a bad idea.

slaren (Member) commented on Aug 6, 2025

We already downcast in almost every matrix multiplication, so I don't think this needs to be an exception. However, it would be good if it respected `GGML_PREC_F32` when requested.
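
For reference, this is how a user requests full FP32 precision on a mul_mat node today (`ggml_mul_mat_set_prec` is existing ggml API; whether the new kernel then avoids the TF32 path is the open question):

```cpp
#include "ggml.h"

// Sketch: build a mul_mat node and ask for full FP32 arithmetic on it.
static struct ggml_tensor * mul_mat_full_prec(struct ggml_context * ctx,
                                              struct ggml_tensor * weights,
                                              struct ggml_tensor * activations) {
    struct ggml_tensor * dst = ggml_mul_mat(ctx, weights, activations);
    ggml_mul_mat_set_prec(dst, GGML_PREC_F32); // opt out of reduced-precision paths such as TF32
    return dst;
}
```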

JohannesGaessler merged commit 1d72c84 into ggml-org:master on Aug 7, 2025, with 93 of 94 checks passed.
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request on Aug 7, 2025.