fix: Add support for truncate_long_and_double in FX
#1865
    ref_output,
    rtol=1e-04,
    atol=1e-04,
    check_dtype=False,
The output data type will be different, since TRT cannot output int64 types
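The reviewer's point can be illustrated without TensorRT. The following is a minimal sketch, using the standard library's `array` module as a stand-in for tensors, of a tolerance-based comparison that can optionally ignore the element type. The helper `assert_close_sketch` is hypothetical; it only mirrors the `rtol`, `atol`, and `check_dtype` arguments in the snippet above and is not the `torch.testing` implementation.

```python
from array import array


def assert_close_sketch(actual, expected, rtol=1e-4, atol=1e-4, check_dtype=True):
    """Hypothetical stand-in for a tolerance-based tensor comparison."""
    if check_dtype and actual.typecode != expected.typecode:
        raise AssertionError(
            f"dtype mismatch: {actual.typecode!r} vs {expected.typecode!r}"
        )
    if len(actual) != len(expected):
        raise AssertionError("length mismatch")
    for a, e in zip(actual, expected):
        # Closeness rule: |a - e| <= atol + rtol * |e|
        if abs(a - e) > atol + rtol * abs(e):
            raise AssertionError(f"{a} is not close to {e}")


# 'q' is a 64-bit integer, standing in for a torch.int64 reference output.
ref_output = array("q", [1, 2, 3])
# 'i' is a C int (typically 32-bit), standing in for the engine's int32 output.
trt_output = array("i", [1, 2, 3])

# The values match, so the comparison passes once the dtype check is disabled.
assert_close_sketch(trt_output, ref_output, check_dtype=False)
```

With `check_dtype=True` the same call raises, which is why the test disables the check when the engine output is narrower than the reference.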
Title changed: "torch.int64 inputs in FX" to "truncate_long_and_double in FX"
Force-pushed: 65bc360 to ad8aecf
    elif dtype == torch.int64:
        if truncate_long_and_double:
            _LOGGER.warning(
                "Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
            )
            return trt.int32
        else:
            raise TypeError(
                "Detected Int64 Input which is not supported by tensorrt, enable compilation "
                + "option truncate_long_and_double=True to cast input to Int32 for TRT Engine"
            )
Similarly to the TorchScript path, allow the truncate_long_and_double argument to automatically cast inputs as needed by TRT Engines, while informing the user. This is primarily helpful for intermediate inputs (not user-provided), which happen to be long-type tensors (such as indices for embeddings).
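The behavior described in this comment can be sketched in isolation. The example below is a simplified, dependency-free illustration: `TorchDType`, `TrtDType`, and `to_trt_dtype` are hypothetical stand-ins for the real `torch` and `tensorrt` dtypes and for the converter in the diff, so it shows the control flow only, not the library's actual API.

```python
import logging
from enum import Enum

_LOGGER = logging.getLogger(__name__)


class TorchDType(Enum):
    """Hypothetical stand-in for the torch dtypes of interest."""
    FLOAT32 = "float32"
    INT32 = "int32"
    INT64 = "int64"


class TrtDType(Enum):
    """Hypothetical stand-in for the dtypes a TRT engine accepts."""
    FLOAT32 = "float32"
    INT32 = "int32"


def to_trt_dtype(dtype, truncate_long_and_double=False):
    """Map a source dtype to an engine dtype, downcasting int64 only on request."""
    if dtype is TorchDType.FLOAT32:
        return TrtDType.FLOAT32
    if dtype is TorchDType.INT32:
        return TrtDType.INT32
    if dtype is TorchDType.INT64:
        if truncate_long_and_double:
            # Inform the user that a narrowing cast will take place.
            _LOGGER.warning("Detected Int64 input, casting to Int32")
            return TrtDType.INT32
        raise TypeError(
            "Int64 input is not supported; "
            "set truncate_long_and_double=True to cast it to Int32"
        )
    raise TypeError(f"{dtype} is not supported")


# Opting in yields the narrower engine dtype instead of an error.
engine_dtype = to_trt_dtype(TorchDType.INT64, truncate_long_and_double=True)
```

Keeping the flag off by default means a narrowing cast never happens silently; the caller has to opt in.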
@gs-olive is this PR still needed?

Yes, this PR is still needed to support T5 in the

Can we create a separate PR for dynamo so we can land the feature there at least?
- Add utility capabilities for accepting `int64` inputs to TRTModules to support multiple use cases
- Support cases include situations where internal tensors in split modules are `int64` (generally used for indexing torch Tensors)
- This also supports cases where the user wants to input `long` tensors as `forward` inputs
- Add test cases to verify functionality and accuracy
- Enable tests for `TRTModuleNext`, which are now fully supported on `main`
- Add support and testing for `double` type inputs
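To see why downcasting `int64` and `double` inputs deserves dedicated tests, it helps to look at what a narrowing cast does to values. The sketch below uses only the standard library; the helper names are hypothetical. It assumes a plain two's-complement cast for integers and an IEEE 754 single-precision rounding for floats. The source does not state how the PR's cast treats out-of-range values.

```python
import ctypes
import struct

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def fits_in_int32(values):
    """True when every value survives an int64 -> int32 cast unchanged."""
    return all(INT32_MIN <= v <= INT32_MAX for v in values)


def wrap_to_int32(value):
    """Keep the low 32 bits (ctypes performs no overflow checking)."""
    return ctypes.c_int32(value).value


def round_to_float32(value):
    """Round a Python float (a double) to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", value))[0]


# Typical embedding indices are small, so the cast is lossless for them.
indices = [0, 17, 32_000]
lossless = fits_in_int32(indices)

# Out-of-range integers wrap around, and most doubles lose precision.
wrapped = wrap_to_int32(2**31)
rounded = round_to_float32(0.1)
```

Index tensors such as embedding lookups normally stay far below the `int32` limit, which is the case the commit message targets.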
@gs-olive can you create separate PRs for each backend? That will make them easier to merge.
Title changed: "truncate_long_and_double in FX" to "truncate_long_and_double in Dynamo"
Title changed: "truncate_long_and_double in Dynamo" to "truncate_long in Dynamo"
Title changed: "truncate_long in Dynamo" to "truncate_long_and_double in FX"
Closed in favor of the more robust #2021 (no need to downcast manually; the FX graph/Dynamo utilities handle this automatically).
Description
Fixes #1864
Addresses #1740