metal lowbit kernels: qmv_fast optimization #2167
Conversation
oh this is the one you were talking about
```diff
@@ -64,12 +64,11 @@ using namespace metal;
 @param [in] B is weight matrix of size M x K. Each byte contains 2 4-bit
 values, along K dim, packed together.
 @param [in] scales_ptr is scales ptr corresponding each
-output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
```
would this work for gemm as well?
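To make the layout change above concrete, here is a minimal sketch in plain C++ (not the kernel code) of how the scale index changes under the transposed packing. The helper names and parameters (`N`, `K`, `group_size`, `n`, `k`) are illustrative assumptions, not identifiers from the PR:

```cpp
#include <cstddef>

// Old packing [num_groups, N]: scales for one group are contiguous across
// output channels, so a fixed output channel n strides by N as k advances.
inline size_t scale_index_old(size_t n, size_t k, size_t N, size_t group_size) {
  return (k / group_size) * N + n;
}

// New packing [N, num_groups = ceil(K / group_size)]: all scales for output
// channel n are contiguous, so walking along K reads consecutive entries.
inline size_t scale_index_new(size_t n, size_t k, size_t K, size_t group_size) {
  size_t num_groups = (K + group_size - 1) / group_size;
  return n * num_groups + (k / group_size);
}
```

Under this sketch's assumptions, the new layout gives a thread contiguous scale reads as it marches along K for a fixed output channel, which is one plausible source of the speedup.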
```metal
// Unpack 8 consecutive 3-bit weights packed little-endian across 3 bytes b[0..2].
w[0] = float(b[0] & 0x07);
w[1] = float((b[0] & 0x38) >> 3);
w[2] = float(((b[0] & 0xc0) >> 6) | ((b[1] & 0x01) << 2));
w[3] = float((b[1] & 0x0e) >> 1);
w[4] = float((b[1] & 0x70) >> 4);
w[5] = float(((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1));
w[6] = float((b[2] & 0x1c) >> 2);
w[7] = float((b[2] & 0xe0) >> 5);
```
I am definitely surprised that this is better
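As a sanity check on the bit layout those masks assume (weight i occupying bits [3*i, 3*i + 2] of a 24-bit little-endian stream), here is a hedged host-side round-trip sketch in plain C++. It is not part of the PR; it only mirrors the kernel's masks and shifts:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  uint8_t vals[8] = {0, 1, 2, 3, 4, 5, 6, 7};  // arbitrary 3-bit values

  // Pack: weight i occupies bits [3*i, 3*i + 2] of a 24-bit stream.
  uint32_t stream = 0;
  for (int i = 0; i < 8; ++i) stream |= uint32_t(vals[i] & 0x7) << (3 * i);
  uint8_t b[3] = {uint8_t(stream), uint8_t(stream >> 8), uint8_t(stream >> 16)};

  // Unpack with the same masks/shifts as the kernel snippet above.
  uint8_t w[8];
  w[0] = b[0] & 0x07;
  w[1] = (b[0] & 0x38) >> 3;
  w[2] = ((b[0] & 0xc0) >> 6) | ((b[1] & 0x01) << 2);
  w[3] = (b[1] & 0x0e) >> 1;
  w[4] = (b[1] & 0x70) >> 4;
  w[5] = ((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1);
  w[6] = (b[2] & 0x1c) >> 2;
  w[7] = (b[2] & 0xe0) >> 5;

  for (int i = 0; i < 8; ++i) assert(w[i] == vals[i]);  // round-trip holds
  return 0;
}
```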
Summary
This PR makes the following modifications:
Performance Improvement
The following tables show the impact of this optimization on Llama 3.1/3.2 decode, using torchchat with MPS compile for text generation on an M1 Max with 64 GB (24 GPU cores, 10 CPU cores).
Llama 3.1-8B
python3 torchchat.py generate llama3.1-base --device mps --dtype float16 --quantize '{"linear:afpwx": {"bitwidth": #BITS, "groupsize": 64}}' --prompt "Once upon a time," --num-samples 5 --compile
Llama 3.2-3B
python3 torchchat.py generate llama3.2-3b-base --device mps --dtype float16 --quantize '{"linear:afpwx": {"bitwidth": #BITS, "groupsize": 64}}' --prompt "Once upon a time," --num-samples 5 --compile
Llama 3.2-1B
python3 torchchat.py generate llama3.2-1b-base --device mps --dtype float16 --quantize '{"linear:afpwx": {"bitwidth": #BITS, "groupsize": 64}}' --prompt "Once upon a time," --num-samples 5 --compile
Performance Summary
The table below summarizes torchchat's speed (tokens/second) on the Metal backend on an M1 Max after this change.