Add per tensor fp8 quantization support for conv3d #3215
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3215. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 94c2e60 with merge base 7e5d907. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Currently waiting for the fbgemm conv op to be available in nightly.
torchao/quantization/quant_api.py (outdated):

```
  activation_granularity, weight_granularity = granularity

- if not _fp8_mm_compat(weight):
+ if weight.dim() != 5 and not _fp8_mm_compat(weight):
```
5 seems a bit arbitrary here without the context, should we add a comment that this is for conv3d?
OK added
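The `weight.dim() != 5` check relies on conv3d weights being 5-dimensional: (out_channels, in_channels/groups, kD, kH, kW). A minimal, torch-free sketch of that distinction (the helper name is hypothetical, not part of torchao):

```python
# A Conv3d weight has 5 dims: (out_channels, in_channels/groups, kD, kH, kW);
# a Linear weight has 2: (out_features, in_features). The PR's
# `weight.dim() != 5` test uses exactly this property to spot conv3d weights.
def is_conv3d_weight_shape(shape):
    """Hypothetical helper: True when the shape matches a conv3d weight."""
    return len(shape) == 5

print(is_conv3d_weight_shape((16, 3, 3, 3, 3)))  # True: conv3d weight
print(is_conv3d_weight_shape((128, 64)))         # False: linear weight
```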
```
"is_MI300",
"is_sm_at_least_89",
"is_sm_at_least_90",
"is_sm_at_least_100",
```
Not super related to this PR, but I wonder if we should stop exposing these; we don't expect users to call these themselves.
Oh, what should users call instead?
Oh, should we use is_sm_version? But I think we need is_sm_at_least.
oh I mean should users even call these helper functions? They know what GPUs they're running on. If they really want to check then maybe they should just check torch.cuda.get_device_capability() >= (10, 0) themselves instead of importing our utils. I'd like to keep them private if possible (in a separate PR)
oh I see, makes sense, yeah this should just be dev only, not user facing
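The suggestion above is for users to check `torch.cuda.get_device_capability() >= (10, 0)` directly rather than importing the helpers. That comparison is a plain tuple comparison; a small sketch with the capability passed in as a parameter so it runs without a GPU:

```python
# torch.cuda.get_device_capability() returns (major, minor), e.g. (9, 0)
# for sm_90. Tuple comparison gives the "at least" semantics directly.
def capability_at_least(capability, required):
    """Sketch of `get_device_capability() >= required` as a pure function."""
    return tuple(capability) >= tuple(required)

print(capability_at_least((9, 0), (8, 9)))   # True: sm_90 >= sm_89
print(capability_at_least((8, 0), (8, 9)))   # False: sm_80 < sm_89
```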
```
stride,
padding,
dilation,
):
```
do we also need to check kernel preference? Like if it's "torch" maybe we should throw an exception since we don't support that yet?
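The guard suggested here could look like the following sketch; the function name and the supported-preference set are assumptions for illustration, not torchao API:

```python
# Hypothetical guard: raise for kernel preferences the fp8 conv path does
# not support yet (the review mentions "torch" as an unsupported value).
_SUPPORTED_CONV_KERNEL_PREFERENCES = {"auto", "fbgemm"}  # assumed set

def check_conv_kernel_preference(kernel_preference):
    if kernel_preference not in _SUPPORTED_CONV_KERNEL_PREFERENCES:
        raise NotImplementedError(
            f"kernel_preference={kernel_preference!r} is not yet supported "
            "for fp8 conv3d"
        )
```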
Summary: we added support for quantizing conv3d weights with the Float8DynamicActivationFloat8WeightConfig API:

```
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerTensor(),
)
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
quantize_(quantized_model, config, filter_fn=_is_conv3d)
```

Test Plan: pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants