CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 #15131
Conversation
test-backend-ops with 3090 Ti:
I don't really like this part, for several reasons. For one thing, small networks for random industrial monitoring/computer vision tasks are often trained at fp32. I find ggml pretty useful as a general tensor processing library, since nothing else in C++ combines being as self-contained dependency-wise with having as broad hardware support. Needless to say, for the generic tensor processing use case simply downcasting is not acceptable. I also just generally find an API that claims to work in fp32 but in reality gives you fp16 with a couple of extra exponent bits pretty tacky. I guess in the end it doesn't really matter for me personally, since when this is ported to AMD GPUs we can have full precision, as the matrix accelerators there have full fp32 (and even fp64) support.
I understand your point regarding precision and can't say that I disagree with it, but I think that "precision" as a function of …

Since you're already here, do you happen to know the correct way to get a 2x BF16 struct like …?
`__hip_bfloat162`, part of `hip_bf16.h`.
With regard to TF32 matmuls, the equivalent PyTorch issue has some further useful discussion on the topic: pytorch/pytorch#67384
Side note (for myself, really): we should not be using `hip_bfloat16` at all anymore, it's deprecated.
I tried including …

My opinion regarding precision is that ggml should just go with the fast option if it can reasonably be assumed that the precision will be good enough for most cases, and only strictly adhere to e.g. IEEE 754 if explicitly requested by the user.
Yeah, you can just use `short2` and I will deal with this later.
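For context, a minimal sketch of what the `short2` workaround could look like on the CUDA side (my own illustration, not code from this PR; on HIP the analogous intrinsics from `hip_bf16.h` would be assumed to exist): the BF16 pair is stored as opaque 16-bit integers and only widened to FP32 where arithmetic is needed.

```cpp
// Illustrative only: short2 as raw storage for a 2x BF16 value.
#include <cuda_bf16.h>
#include <cstring>

__device__ __forceinline__ float2 bf16x2_to_float2(const short2 v) {
    __nv_bfloat162 b;
    memcpy(&b, &v, sizeof(b));      // reinterpret the raw 2x16-bit payload as 2x BF16
    return __bfloat1622float2(b);   // widen to FP32 only where arithmetic happens
}
```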
I think that's OK for the most part, but by selecting fp32 the user has already told you that he cares mostly about precision, so IMO in this case this is a bad idea.
We already downcast in almost every matrix multiplication, so I don't think this needs to be an exception. However, it would be good if it respected …
This PR adds a GEMM kernel for FP32/FP16/BF16 and batch sizes <= 16. The kernel is written in such a way that it does not need any synchronization between warps except for a constant number of times at the end. It can also directly handle mixed input/output types without the need for auxiliary kernels.

The kernel uses tensor cores, so FP16 needs Turing or newer; BF16 and FP32 need Ampere or newer. FP32 data is downcast to TF32, but I think that for neural networks this does not make a meaningful difference (and there would only ever be a difference for a model trained at FP32 precision in the first place).

The end-to-end speedup for floating-point models is something like 5-20%; quantized models without FlashAttention also gain a few percent.
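The kernel itself is built on ggml's own mma primitives, but the FP32-via-TF32 path can be illustrated with a hedged, minimal WMMA sketch (my own illustration, not code from this PR): one warp computes a 16x16 output tile from FP32 inputs by explicitly rounding them to TF32 before the tensor-core MMA, which requires Ampere (sm_80) or newer.

```cpp
// Illustrative sketch: single m16 x n16 x k8 TF32 tensor-core tile, one warp.
// A is row-major, B is column-major, C accumulates and stores in FP32.
#include <mma.h>
using namespace nvcuda;

__global__ void tf32_tile_gemm(const float * A, const float * B, float * C,
                               int lda, int ldb, int ldc) {
#if __CUDA_ARCH__ >= 800 // TF32 tensor core MMA requires Ampere or newer
    wmma::fragment<wmma::matrix_a,    16, 16, 8, wmma::precision::tf32, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 8, wmma::precision::tf32, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 8, float>                                  c;

    wmma::fill_fragment(c, 0.0f);
    wmma::load_matrix_sync(a, A, lda);
    wmma::load_matrix_sync(b, B, ldb);

    // FP32 inputs have to be explicitly rounded to TF32 before the tensor core op:
    for (int i = 0; i < a.num_elements; ++i) a.x[i] = wmma::__float_to_tf32(a.x[i]);
    for (int i = 0; i < b.num_elements; ++i) b.x[i] = wmma::__float_to_tf32(b.x[i]);

    wmma::mma_sync(c, a, b, c);
    wmma::store_matrix_sync(C, c, ldc, wmma::mem_row_major);
#endif
}
```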
I renamed `NEW_MMA_AVAILABLE` to `TURING_MMA_AVAILABLE` and added a new define `AMPERE_MMA_AVAILABLE`. Support for `MUL_MAT_ID` is not yet implemented.

Performance changes