-
Notifications
You must be signed in to change notification settings - Fork 67
Added dynamic-shape 0/1 bucketing: "zero_nonzero" env var #1053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Added dynamic-shape 0/1 bucketing: "zero_nonzero" env var #1053
Conversation
…th in a diff way while 1 just goes >= 2 similar to already present '_tensor_key' logic & this can be turned on with an environment variable 'HELION_SHAPE_BUCKETING' that default to 'min2', and disables 0/1 when using 'zero_nonzero'.
|
Hi @Itssshikhar! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
|
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
|
please rebase to main and make sure the lint is passing, and move your PR from draft to ready for review when you're ready |
|
Appreciate the comments. I've rebase the main to this branch and lint seems to be passing. Let me know if there is anything else. |
jansel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fixes the bucketing that happens inside the kernel.py specialization key, but I think it may have some correctness issues because there is still similar shape bucketing coming from ShapeEnv/FakeTensors (which are part of PyTorch). So the code the is generated may not be shape agnostic, but this will cause it to be used as if it is.
You should be able to surface this issue by adding tests that:
- Run the full kernel with varying 1-ness of shapes
- Inspect the generated output code (via assertExpectedJournal) so you can see if shape specialization is happening.
…ne function to autotuner (pytorch#1054) Rebased to main
…would still get specialized.
|
Okay, so you were right. Even though bucketing inside And the reason for this was because there were 2 different kernels coming from Not sure if I explained this correctly but let me know if there is something I've missed. |
| # When disabling 0/1 specialization (zero_nonzero), ensure non-zero dims are symbolic | ||
| if ( | ||
| not self.settings.static_shapes | ||
| and getattr(self.settings, "shape_bucketing", "min2") == "zero_nonzero" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't need getattr here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, I was just trying to be on the safe side. Did the same in _kernel_type but it's really not necessary.
| baseline_config, | ||
| prefix=f"Generated Triton code for {decorator}:", | ||
| ) | ||
| <<<<<<< HEAD |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let me clear this up.
| def format_kernel_decorator(self, config: Config, settings: Settings) -> str: | ||
| """Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation.""" | ||
| # Include shape_bucketing only when non-default to keep logs compact | ||
| if getattr(settings, "shape_bucketing", "min2") != "min2": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why getattr?
| ) | ||
| # Non-static path: bucket sizes for specialization. Default is 0/1/>=2 (as 2). | ||
| vals = tuple([min(s, 2) for s in obj.size()]) | ||
| if getattr(fn.settings, "shape_bucketing", "min2") == "zero_nonzero": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same
| # Controls non-static shape specialization bucketing. When "min2" (default), | ||
| # we bucket dynamic sizes per-dimension into 0, 1, or >=2 (represented as 2). | ||
| # When "zero_nonzero", we keep 0 distinct and unify 1 with >=2 to reduce churn. | ||
| shape_bucketing: Literal["min2", "zero_nonzero"] = dataclasses.field( | ||
| default_factory=_get_shape_bucketing | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After some though, perhaps instead of adding a new config we should make static_shapes an enum of "all", "ones", "none". Since if I set static_shapes=True this does nothing.
We will need backcompat for True/False, but that might result in a cleaner config.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, so I was thinking we can do something like this:
- static_shape = "all" would be equivalent to setting static_shape=True
- static_shape = "ones" would be the "min2" case, meaning specialize 0/1.
- static_shape = "none" would be this "zero_nonzero" case, basically disabling 0/1 specialization.
To make backcompat for True/False, we can set them as True->"all" & False->"none" and then HELION_STATIC_SHAPES can go through "all", "ones", "none".
|
|
||
| y2 = torch.empty_like(x2) | ||
| y1 = torch.empty_like(x1) | ||
| pw_add(x2, y2) # compile |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should test the size==1 version first, since I'd expect doing it in this order would work without this change.
| if code1 == code2: | ||
| self.skipTest( | ||
| "Generated Triton is identical for M=1 and M=2; no singleton specialization detected" | ||
| ) | ||
| else: | ||
| # Expect differing code paths when singleton specialization is present | ||
| self.assertNotEqual(code1, code2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test cannot fail. It turns failures into skips.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will remove the conditional here.
| y2 = torch.empty_like(x2) | ||
| b2 = k2.bind((x2, y2)) | ||
| code2 = b2.to_triton_code() | ||
| self.assertExpectedJournal(code2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to journal the code twice if you assert they are the same on the next line.
Remove the extra entry from the *.expected file manually (rm the file and regernate).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense.
| for i in grid(x.size(0)): | ||
| for j in grid(x.size(1)): | ||
| for k in grid(x.size(2)): | ||
| out[i, j, k] = x[i, j, k] + 1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do all the tests use grid rather than tile? This seems a bit odd.
| for i in grid(x.size(0)): | |
| for j in grid(x.size(1)): | |
| for k in grid(x.size(2)): | |
| out[i, j, k] = x[i, j, k] + 1.0 | |
| for i, j, k in tile(x.size()): | |
| out[i, j, k] = x[i, j, k] + 1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
more of a habit of using grid, coming from triton. using tile would make sense here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a reduction test (using the sum example kernel).
Have you tried manually running some of the examples with this flag set to shake out any other bugs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm trying to find out what all bugs this change might introduce. Will do more tests, along with adding reduction kernel.
Resolves #934
Used a zero-nonzero specialization in which 0 is kept distinct while 1 just goes >= 2 similar to already present '_tensor_key' logic & this can be turned on with an environment variable 'HELION_SHAPE_BUCKETING' that defaults to 'min2' (current behavior: bucket dims are 0, 1, >=2), and disables 0/1 when using 'zero_nonzero'. Kept 0 as a separate bucket to deal with zero-numel edge cases. Added test for the same as well.