
✨[Feature] Automatic type promotion for elementwise ops in FX #2031

Closed
@gs-olive


Context

Currently, elementwise operators are not automatically type-promoted in the FX frontend the way they are in the TorchScript (TS) frontend. This leads to bugs such as #1995, where the operand types are mismatched and TensorRT throws an error.

Feature Proposal

Using the TS type-promotion code as a starting point:

```cpp
// Promote two TensorRT types using PyTorch's rules: convert both to
// at::ScalarType, apply at::promote_types, then convert the result back.
nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
  auto torch_type_a = util::TRTDataTypeToScalarType(type_a);
  auto torch_type_b = util::TRTDataTypeToScalarType(type_b);
  auto promo_type = at::promote_types(torch_type_a, torch_type_b);
  auto trt_promo_type = util::ScalarTypeToTRTDataType(promo_type);
  return trt_promo_type;
}
```
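The TS helper delegates the actual rule to `at::promote_types`. As a rough illustration of what that rule does for the TensorRT types involved here, the sketch below reimplements it in plain Python for a small subset of dtypes. `DType` is a hypothetical stand-in for `trt.DataType`, not the real enum. Unsigned and bfloat16 types are deliberately left out, because PyTorch's rule for them is not a simple "wider wins" (for example, `uint8` with `int8` promotes to `int16`).

```python
from enum import Enum


class DType(Enum):
    # Hypothetical stand-in for a subset of trt.DataType.
    # Each value is (category, width in bits); bool < integer < floating point.
    BOOL = (0, 8)
    INT8 = (1, 8)
    INT32 = (1, 32)
    HALF = (2, 16)
    FLOAT = (2, 32)


def promote_types(type_a: DType, type_b: DType) -> DType:
    """Approximates torch.promote_types for the dtypes above only.

    A higher category always wins; within a category, the wider type wins.
    """
    cat_a, width_a = type_a.value
    cat_b, width_b = type_b.value
    if cat_a != cat_b:
        return type_a if cat_a > cat_b else type_b
    return type_a if width_a >= width_b else type_b
```

For example, `promote_types(DType.INT32, DType.HALF)` yields `DType.HALF`, matching `torch.promote_types(torch.int32, torch.float16)`. In the real converter, calling `torch.promote_types` on the mapped torch dtypes (as the C++ code does with `at::promote_types`) avoids re-encoding the rules by hand.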

Implement a similar type-promotion scheme in `convert_binary_elementwise` on the `converter_reorg_elementwise` branch:

```python
def convert_binary_elementwise(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    op_type: trt.ElementWiseOperation,
    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
) -> TRTTensor:
    ...  # existing body omitted; lhs_val and rhs_val should be promoted to a common type here
```
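Because `lhs_val` and `rhs_val` may be Python scalars as well as tensors, the converter has to handle more than the tensor/tensor case covered by `promote_types`. The sketch below shows one way to decide the common dtype and which tensor operands need a cast, following PyTorch's documented behaviour that a Python scalar only affects the result when its category (bool < integer < float) is higher than the tensor's. Everything here is hypothetical: `FakeTensor` stands in for `TRTTensor`/`torch.Tensor`, dtypes are plain strings, and `int32` is assumed as the integer default because TensorRT lacks `int64` (PyTorch itself would pick `int64`). Zero-dimensional tensors, which PyTorch treats with their own priority, are not modelled.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

# (category, width in bits) for the dtypes handled by this sketch.
_RANK = {
    "bool": (0, 8),
    "int8": (1, 8),
    "int32": (1, 32),
    "float16": (2, 16),
    "float32": (2, 32),
}
# Assumed dtype for a Python scalar that outranks the tensor's category.
_SCALAR_DEFAULT = {1: "int32", 2: "float32"}


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a TRTTensor / torch.Tensor operand."""
    dtype: str


Operand = Union[bool, int, float, FakeTensor]


def _category(value: Operand) -> int:
    if isinstance(value, FakeTensor):
        return _RANK[value.dtype][0]
    if isinstance(value, bool):  # check bool first: bool subclasses int
        return 0
    return 1 if isinstance(value, int) else 2


def _promote(a: str, b: str) -> str:
    (cat_a, w_a), (cat_b, w_b) = _RANK[a], _RANK[b]
    if cat_a != cat_b:
        return a if cat_a > cat_b else b
    return a if w_a >= w_b else b


def compute_dtype(lhs: Operand, rhs: Operand) -> str:
    """Common dtype both operands should have before the elementwise layer."""
    tensors = [v for v in (lhs, rhs) if isinstance(v, FakeTensor)]
    if len(tensors) == 2:
        return _promote(lhs.dtype, rhs.dtype)
    if not tensors:
        raise ValueError("at least one operand must be a tensor")
    tensor = tensors[0]
    scalar = rhs if tensor is lhs else lhs
    scalar_cat = _category(scalar)
    if scalar_cat > _RANK[tensor.dtype][0]:
        return _SCALAR_DEFAULT[scalar_cat]
    return tensor.dtype


def operands_to_cast(lhs: Operand, rhs: Operand) -> List[Tuple[str, str]]:
    """(side, target dtype) for each tensor operand whose dtype must change.

    Scalars are not listed: they would be materialized as constants of the
    target dtype directly.
    """
    target = compute_dtype(lhs, rhs)
    return [
        (side, target)
        for side, v in (("lhs", lhs), ("rhs", rhs))
        if isinstance(v, FakeTensor) and v.dtype != target
    ]
```

For instance, an `int32` tensor added to `0.5` would need its `lhs` cast to `float32`, while a `float16` tensor added to `0.5` stays `float16`. How the cast is emitted in the TensorRT network (for example, via an identity layer with its output type set) is left open here.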

Labels: No Activity, component: dynamo (issues relating to the `torch.compile` or `torch._dynamo.export` paths), feature request (new feature or request)