Description
Feature Context
Models that are fully supported in TRT, except that their input type is a collection, should be fully compilable in Torch-TRT. Torch-executed list packing and unpacking code is already inserted (by necessity) even when a model is fully supported, so there should be no need to disable full compilation when complex input types are provided. Additionally, operators such as `prim::ListUnpack` should not be added to `torch_executed_ops` automatically when `input_signature` is used, as they currently are, because evaluators for them already exist.
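To make the packing/unpacking argument concrete, here is a minimal pure-Python sketch. It is not Torch-TRT code: `flatten_inputs`, `unflatten`, `run_fully_compiled` and `fake_engine` are hypothetical names, plain Python values stand in for tensors, and a single callable stands in for one TensorRT engine. It only illustrates that unpacking a nested tuple/list input into flat tensors (the role of `prim::TupleUnpack`/`prim::ListUnpack`) and repacking the results (the role of `prim::TupleConstruct`/`prim::ListConstruct`) are thin wrappers around a single engine call.

```python
from typing import Any, Callable, List, Tuple


def flatten_inputs(x: Any) -> Tuple[List[Any], Any]:
    """Flatten nested tuples/lists into a flat list plus a structure spec.

    Stand-in for prim::TupleUnpack / prim::ListUnpack. The spec records each
    container's type and children so the structure can be rebuilt later.
    Only plain tuple and list containers are handled in this sketch.
    """
    if isinstance(x, (tuple, list)):
        flat: List[Any] = []
        children = []
        for item in x:
            item_flat, item_spec = flatten_inputs(item)
            flat.extend(item_flat)
            children.append(item_spec)
        return flat, (type(x), children)
    # A leaf (e.g. a tensor) is marked by a spec of None.
    return [x], None


def unflatten(flat: List[Any], spec: Any) -> Any:
    """Rebuild the nested structure (stand-in for prim::TupleConstruct / prim::ListConstruct)."""
    it = iter(flat)

    def build(s: Any) -> Any:
        if s is None:
            return next(it)
        container_type, children = s
        return container_type(build(c) for c in children)

    return build(spec)


def run_fully_compiled(engine: Callable[..., List[Any]], inputs: Any) -> Any:
    """Unpack on the Torch side, call one engine on flat inputs, repack on the Torch side."""
    flat_in, spec = flatten_inputs(inputs)
    flat_out = engine(*flat_in)  # the whole model body runs in a single "engine"
    return unflatten(list(flat_out), spec)


def fake_engine(*tensors: Any) -> List[Any]:
    """Hypothetical engine: maps each flat input to one output of the same position."""
    return [t * 2 for t in tensors]


result = run_fully_compiled(fake_engine, ((1, 2), [3]))
```

The toy engine returns outputs with the same structure as its inputs purely to keep the sketch short; a real model's output structure would be tracked separately.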
Desired Solution
The preferred solution is to remove the requirement for `require_full_compilation=False` when using `input_signature`, and to remove the requirement that collection-based operators be executed in fallback:
TensorRT/py/torch_tensorrt/ts/_compile_spec.py, lines 259 to 300 at commit 835abf0:
```python
elif compile_spec["input_signature"] is not None:
    log(
        Level.Warning,
        "Input signature parsing is an experimental feature, behavior and APIs may change",
    )
    signature = _parse_input_signature(compile_spec["input_signature"])
    info.input_signature = _C.InputSignature(signature)  # py_object
    if not compile_spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
        )
    log(
        Level.Debug,
        "Grouped inputs currently requires additional settings to enable the feature",
    )
    log(
        Level.Debug,
        """Adding the following ops to torch_executed_ops:
    - aten::__getitem__
    - prim::ListConstruct
    - prim::ListUnpack
    - prim::TupleIndex
    - prim::TupleConstruct
    - prim::TupleUnpack
    """,
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "aten::__getitem__"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::ListConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleUnpack"
    )
```
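The difference between the current and the requested behavior of this branch can be modeled in isolation. This is a simplified, hypothetical sketch, not the actual `_compile_spec.py` code: `compile_spec` is a plain dict, logging and `_parse_input_signature` are omitted, and `COLLECTION_OPS`, `current_behavior` and `proposed_behavior` are illustrative names. In the sketch, `torch_fallback.enabled = False` plays the role of `require_full_compilation=True`.

```python
import copy
from typing import Any, Dict

# The six ops the current branch appends to forced_fallback_ops.
COLLECTION_OPS = [
    "aten::__getitem__",
    "prim::ListConstruct",
    "prim::ListUnpack",
    "prim::TupleIndex",
    "prim::TupleConstruct",
    "prim::TupleUnpack",
]


def current_behavior(compile_spec: Dict[str, Any]) -> Dict[str, Any]:
    """Model of today's branch: fallback is mandatory and collection ops are forced to Torch."""
    spec = copy.deepcopy(compile_spec)  # do not mutate the caller's dict
    if not spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled"
        )
    spec["torch_fallback"]["forced_fallback_ops"].extend(COLLECTION_OPS)
    return spec


def proposed_behavior(compile_spec: Dict[str, Any]) -> Dict[str, Any]:
    """Model of the request: no fallback requirement and no forced collection ops.

    Collection ops are left to their existing evaluators, so the spec passes through unchanged.
    """
    return copy.deepcopy(compile_spec)


full = {"torch_fallback": {"enabled": False, "forced_fallback_ops": []}}
partial = {"torch_fallback": {"enabled": True, "forced_fallback_ops": []}}

# Today: full compilation combined with input_signature is rejected outright.
try:
    current_behavior(full)
except ValueError as err:
    print("rejected:", err)

# Today: partial compilation is accepted, but the collection ops are forced into Torch.
forced_now = current_behavior(partial)["torch_fallback"]["forced_fallback_ops"]

# Requested: the full-compilation spec is accepted and nothing is forced into fallback.
forced_proposed = proposed_behavior(full)["torch_fallback"]["forced_fallback_ops"]
```

Copying the spec keeps both functions free of side effects, which makes the two behaviors easy to compare on the same input.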
This would also require modifying the C++ `core` code, to ensure that relaxing this requirement does not cause further issues in the existing compilation phases.
Additional Context
A proof-of-concept for this feature already exists in PR #1599, which could be used as a template to enable full-compilation functionality for collection inputs as well. This would complete the plan for Collection IO as discussed in #629 (comment).