-
Notifications
You must be signed in to change notification settings - Fork 370
feat: Add Selective ATen decompositions #2173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
34a190e to
bdb06d8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
bdb06d8 to
368a20e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
| ENABLED_TORCH_DECOMPOSITIONS: Dict[ | ||
| torch._ops.OpOverload, Callable | ||
| ] = get_torch_decompositions(enabled_decompositions) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, the decompositions are sourced directly from Torch's main registry (via get_torch_decompositions) and may not exactly match with the _core_aten_decompositions. This is because certain decompositions which we depend on (such as native_layer_norm, may occasionally be removed from the core set).
Whenever Torch versions are upgraded, this list should be updated as well.
| ENABLED_TORCH_DECOMPOSITIONS: Dict[ | ||
| torch._ops.OpOverload, Callable | ||
| ] = get_torch_decompositions(enabled_decompositions) | ||
| TORCH_TRT_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The decompositions are three dictionaries:
ENABLED_TORCH_DECOMPOSITIONS- the enabled decompositions we've pre-selectedCORE_ATEN_DECOMPOSITIONS_FILTERED(defined inget_decompositionsbelow) - the complete set of_core_aten_decompositionsTorch provides, minus the set of disabled decompositions. Note thatTORCH_DECOMPOSITIONSmay not be a subset of this setTORCH_TRT_DECOMPOSITIONS- the decompositions we've written ourselves
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would TORCH_DECOMPOSITIONS only include decompositions from the get_torch_decompoistions set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that TORCH_DECOMPOSITIONS may not be a subset of this set
From what I understand,
- it seems like some decompositions in
_core_aten_decompositionshave been removed eg:aten.native_layer_norm. ENABLED_TORCH_DECOMPOSITIONS is a more complete set (from previous commit maybe ). Is this correct ? - In that case, what if we move aten.native.layer_norm to TORCH_TRT_DECOMPOSITIONS since it is useful to us and maybe other useful ones instead of maintaining a
ENABLED_TORCH_DECOMPOSITIONSwhich overlaps with_core_aten_decompositionsone ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@narendasan
ENABLED_TORCH_DECOMPOSITIONS would only include decompositions from the _core_aten_decompositions set which are not also in disabled_decompositions
- The interpretation of
ENABLED_TORCH_DECOMPOSITIONSis correct - My initial intent for
TORCH_TRT_DECOMPOSITIONSwas that it would only store decompositions we specifically (custom) wrote, not ones sourced from Torch, aslayer_normwould be.
|
|
||
| def get_decompositions(): | ||
| return DECOMPOSITIONS | ||
| def get_decompositions( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dictionary returned by get_decompositions is either ENABLED_TORCH_DECOMPOSITIONS or CORE_ATEN_DECOMPOSITIONS_FILTERED concatenated with our TORCH_TRT_DECOMPOSITIONS
|
Have we thought about what this might look like if its user accessible? |
|
Can we add a tool to monitor these decomposition sets similar to the opset coverage tool? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
8b12de5 to
806b348
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
|
The existing opset coverage tool is compatible with this PR, meaning that |
806b348 to
7fa036f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
7fa036f to
d49cadb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
- Add sets to selectively enable or disable decompositions in Torch - Add new runtime argument `enable_experimental_decompositions` to enable all core aten decompositions, or a pre-selected subset thereof - Improve documentation of compilation settings overall
- Add decorator-wrapper to perform import-time checks on decompositions and alert the user if any custom decompositions conflict with existing registered or specified operators - Simplify code logic for dictionary merging in `get_decompositions` function - Add safety logic to ensure invariants about the decompositions are not violated
d49cadb to
1e3d12e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
peri044
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
1e3d12e to
2064f4f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
Description
enable_experimental_decompositionsto enable all core aten decompositions, or a pre-selected subset thereofFixes #2160
Type of change
Checklist: