CUDA: fix crash on large batch size for quant. MoE #13537
Conversation
```diff
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
     constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
     constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;

-    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
```
Shouldn't this be `blockDim.y`?
`blockDim` refers to the maximum extents of `threadIdx`. This PR only swaps the grid dimensions; the configuration of threads within a block was not changed, therefore `blockDim.x` is still correct.
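For illustration, here is a standalone sketch of that distinction (the kernel name `index_demo` and the one-dimensional thread block are assumptions for the demo, not code from this PR): the block's thread layout, and hence `blockDim`, is independent of the grid shape, so only the `blockIdx` coordinate changes.

```cuda
#include <cstdint>

// Standalone demo, not llama.cpp code: a 1D thread block laid over a 2D grid.
// blockDim describes the block's own thread layout, so blockDim.x stays valid
// even though the within-row block index now comes from blockIdx.y.
__global__ void index_demo(int64_t * out, int64_t ncols) {
    const int64_t i0  = ((int64_t) blockDim.x*blockIdx.y + threadIdx.x)*4; // offset within a row
    const int64_t row = blockIdx.x;                                        // row index from grid.x

    if (i0 >= ncols) {
        return; // guard the partial block at the end of a row
    }
    out[row*ncols + i0] = i0;
}
```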
Yeah, this fixed it for me - thanks!
Should fix the issue described in #13435 (comment).
This PR swaps the x and y dimensions of the CUDA grid used for quantizing the activations: the y dimension of a grid is capped at 65535 blocks, while the x dimension allows up to 2^31 - 1, so the row count, which grows with the batch size, now maps to x.
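As a rough sketch of what that swap looks like on the launch side (the names `launch_quantize_sketch` and `CUDA_QUANTIZE_BLOCK_SIZE`, and the block size of 128, are assumptions; the actual llama.cpp launch code differs):

```cuda
#include <cstdint>
#include <cuda_runtime.h>

constexpr int CUDA_QUANTIZE_BLOCK_SIZE = 128; // threads per block (x only), assumed value

static void launch_quantize_sketch(int64_t ne0, int64_t nrows) {
    // Each thread quantizes 4 values, so a row of ne0 values needs this many blocks:
    const unsigned int blocks_per_row =
        (unsigned int) ((ne0/4 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE);

    // Before the fix: dim3 grid(blocks_per_row, nrows, 1);
    // grid.y is capped at 65535 blocks and nrows grows with the batch size,
    // so large batches exceeded the limit and the kernel launch failed.
    // After the fix the row count is mapped to grid.x (limit 2^31 - 1):
    const dim3 grid((unsigned int) nrows, blocks_per_row, 1);
    const dim3 block(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
    // quantize_mmq_q8_1<<<grid, block>>>(/* args elided */);
}
```

With this mapping the batch-dependent row count can grow far beyond 65535 before hitting a grid limit, while the per-row block count, bounded by the row width, comfortably fits in the 65535-block y dimension.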