diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp
index 94ac827ef4..ded4671e56 100644
--- a/core/conversion/converters/converter_util.cpp
+++ b/core/conversion/converters/converter_util.cpp
@@ -59,6 +59,14 @@ nvinfer1::ITensor* addUnpadding(
   }
 }
 
+nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
+  auto torch_type_a = util::TRTDataTypeToScalarType(type_a);
+  auto torch_type_b = util::TRTDataTypeToScalarType(type_b);
+  auto promo_type = at::promote_types(torch_type_a, torch_type_b);
+  auto trt_promo_type = util::ScalarTypeToTRTDataType(promo_type);
+  return trt_promo_type;
+}
+
 nvinfer1::ILayer* add_elementwise(
     ConversionCtx* ctx,
     nvinfer1::ElementWiseOperation op,
@@ -78,6 +86,26 @@ nvinfer1::ILayer* add_elementwise(
     std::swap(self, other);
     swapSelfOther = true;
   }
+
+  if (self->getType() != other->getType()) {
+    LOG_DEBUG(
+        "Type mismatch for inputs in element-wise operation " << name << ": " << self->getType() << ", "
+                                                              << other->getType());
+    auto promo_type = promote_types(self->getType(), other->getType());
+    if (self->getType() != promo_type) {
+      LOG_DEBUG(
+          "Element-wise op type promotion adding cast from " << self->getType() << " to " << promo_type << " for layer "
+                                                             << name);
+      self = castITensor(ctx, self, promo_type);
+    }
+    if (other->getType() != promo_type) {
+      LOG_DEBUG(
+          "Element-wise op type promotion adding cast from " << other->getType() << " to " << promo_type
+                                                             << " for layer " << name);
+      other = castITensor(ctx, other, promo_type);
+    }
+  }
+
   auto selfDim = util::toVec(self->getDimensions());
   auto otherDim = util::toVec(other->getDimensions());
   if (selfDim.size() != otherDim.size()) {
diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
old mode 100644
new mode 100755
index f2770508ca..5c5966841a
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -55,10 +55,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              // Should implement self + alpha * other
              auto self = args[0].ITensorOrFreeze(ctx);
              auto other = args[1].ITensorOrFreeze(ctx);
-             auto scalar = args[2].unwrapToScalar().to<float>();
+             auto scalar = args[2].unwrapToScalar();
 
-             if (1 != scalar) {
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+             if (1 != scalar.to<float>()) {
+               auto alphaTensor = scalar_to_tensor(ctx, scalar);
                auto scaleLayer = add_elementwise(
                    ctx,
                    nvinfer1::ElementWiseOperation::kPROD,
@@ -84,10 +84,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              // Should implement self + alpha * other
              auto self = args[0].ITensorOrFreeze(ctx);
              auto other = args[1].ITensorOrFreeze(ctx);
-             auto scalar = args[2].unwrapToScalar().to<float>();
+             auto scalar = args[2].unwrapToScalar();
 
-             if (1 != scalar) {
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+             if (1 != scalar.to<float>()) {
+               auto alphaTensor = scalar_to_tensor(ctx, scalar);
                auto scaleLayer = add_elementwise(
                    ctx,
                    nvinfer1::ElementWiseOperation::kPROD,
@@ -262,12 +262,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              // Should implement other - alpha * self
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto otherScalar = args[1].unwrapToScalar().to<float>();
-             auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
-             auto scalar = args[2].unwrapToScalar().to<float>();
+             auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
+             auto scalar = args[2].unwrapToScalar();
 
-             if (1 != scalar) {
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+             if (1 != scalar.to<float>()) {
+               auto alphaTensor = scalar_to_tensor(ctx, scalar);
                auto scaleLayer = add_elementwise(
                    ctx,
                    nvinfer1::ElementWiseOperation::kPROD,
@@ -292,10 +291,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              // Should implement other - alpha * self
              auto self = args[0].ITensorOrFreeze(ctx);
              auto other = args[1].ITensorOrFreeze(ctx);
-             auto scalar = args[2].unwrapToScalar().to<float>();
+             auto scalar = args[2].unwrapToScalar();
 
-             if (1 != scalar) {
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+             if (1 != scalar.to<float>()) {
+               auto alphaTensor = scalar_to_tensor(ctx, scalar);
                auto scaleLayer = add_elementwise(
                    ctx,
                    nvinfer1::ElementWiseOperation::kPROD,
@@ -418,7 +417,6 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              // Should implement self * other
              auto self = args[0].ITensorOrFreeze(ctx);
              auto other = args[1].ITensorOrFreeze(ctx);
-
              auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
              TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
 
@@ -433,7 +431,6 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              // TODO: Remove with functionalization
              auto self = args[0].ITensorOrFreeze(ctx);
              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
-
              auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
              TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
 
diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp
index 061eaa500e..3ecfdb2019 100644
--- a/tests/core/conversion/converters/test_element_wise.cpp
+++ b/tests/core/conversion/converters/test_element_wise.cpp
@@ -12,40 +12,29 @@ void pointwise_test_helper(
     std::vector<int64_t> shape1 = {5},
     std::vector<int64_t> shape2 = {5},
     bool negative_input = false,
-    bool int_tensors = false,
-    bool float_int_tensors = false,
-    bool int_float_tensors = false) {
+    at::ScalarType type1 = at::kFloat,
+    at::ScalarType type2 = at::kFloat) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());
 
   // singleInput case is enabled when elementwise operation is performed
   // with an input and a constant embedded in graph
   std::vector<at::Tensor> torch_inputs;
-  if (negative_input) {
-    torch_inputs.push_back(at::randint(-5, 5, shape1, {at::kCUDA}));
-  } else {
-    torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
+  int first_min = negative_input ? -5 : 1;
+  int first_max = 5;
+  int second_min = 1;
+  int second_max = 5;
+  if (type1 == at::kBool) {
+    first_min = 0;
+    first_max = 1;
   }
-  if (!singleInput) {
-    torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
+  if (type2 == at::kBool) {
+    second_min = 0;
+    second_max = 1;
   }
-
-  TORCHTRT_CHECK(
-      !((int_tensors && (float_int_tensors || int_float_tensors)) || (float_int_tensors && int_float_tensors)),
-      "Invalid test configuration, only one of int_tensors, float_int_tensors, int_float_tensors can be true");
-
-  if (int_tensors) {
-    for (size_t i = 0UL; i < torch_inputs.size(); ++i) {
-      torch_inputs[i] = torch_inputs[i].to(at::kInt);
-    }
-  } else if (float_int_tensors) {
-    TORCHTRT_CHECK(!singleInput, "float_int_tensors tests require two inputs");
-    torch_inputs[0] = torch_inputs[0].to(at::kFloat);
-    torch_inputs[1] = torch_inputs[1].to(at::kInt);
-  } else if (int_float_tensors) {
-    TORCHTRT_CHECK(!singleInput, "int_float_tensors tests require two inputs");
-    torch_inputs[0] = torch_inputs[0].to(at::kInt);
-    torch_inputs[1] = torch_inputs[1].to(at::kFloat);
+  torch_inputs.push_back(at::randint(first_min, first_max, shape1, at::TensorOptions(at::kCUDA).dtype(type1)));
+  if (!singleInput) {
+    torch_inputs.push_back(at::randint(second_min, second_max, shape2, at::TensorOptions(at::kCUDA).dtype(type2)));
   }
 
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
@@ -78,8 +67,6 @@ TEST(Converters, ATenAddConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -93,8 +80,8 @@
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
@@ -106,6 +93,17 @@
   pointwise_test_helper(graph, false);
   pointwise_test_helper(graph, false, false, {3, 4}, {4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kFloat, at::kInt);
+}
+
+TEST(Converters, ATenAddImplicitWithIntAlphaConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=42]()
+        %3 : Tensor = aten::add_(%0, %1, %2)
+        return (%3))IR";
+  pointwise_test_helper(graph, false, false, {2, 2}, {2, 2}, false, at::kInt, at::kInt);
+  pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
@@ -129,8 +127,8 @@ TEST(Converters, ATenSubConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenMulConvertsCorrectly) {
@@ -143,8 +141,8 @@
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
@@ -162,7 +160,7 @@ TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor):
        %scalar : int = prim::Constant[value=2]()
        %1 : Tensor = aten::mul(%0, %scalar)
        return (%1))IR";
-  pointwise_test_helper(graph, true, false, {5}, {5}, false, true);
+  pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
 }
@@ -175,8 +173,6 @@ TEST(Converters, ATenDivConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -199,8 +195,8 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -214,8 +210,8 @@
   pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -241,8 +237,8 @@ TEST(Converters, ATenPowTensorConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenPowScalarConvertsCorrectly) {
@@ -283,8 +279,8 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
-  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
 }
 
 TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
@@ -329,6 +325,8 @@ TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {3, 4}, {4});
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
+  pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
@@ -341,6 +339,16 @@
   pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
 }
 
+TEST(Converters, ATenRsubWithIntScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %2 : int = prim::Constant[value=2]()
+        %scalar : int = prim::Constant[value=8]()
+        %3 : Tensor = aten::rsub(%0, %scalar, %2)
+        return (%3))IR";
+  pointwise_test_helper(graph, true, false, {4, 3, 3, 3}, {}, false, at::kInt);
+}
+
 TEST(Converters, ATenClampMinConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
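Note for reviewers: the new promote_types helper maps both TensorRT input types back to ATen scalar types and defers to ATen's own promotion table via at::promote_types, so mixed-type element-wise ops in a TRT engine follow the same dtype rules as eager PyTorch. Below is a minimal standalone sketch of the promotion behavior this patch relies on, assuming only a libtorch install; it is illustrative and not part of the change itself.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Int + Float promotes to Float, so the Int operand receives a cast.
  std::cout << at::promote_types(at::kInt, at::kFloat) << std::endl; // Float
  // Int + Int stays Int: types already match, so no cast is inserted.
  std::cout << at::promote_types(at::kInt, at::kInt) << std::endl; // Int
  // Bool + Int promotes to Int, which the new kBool {0, 1} test inputs exercise.
  std::cout << at::promote_types(at::kBool, at::kInt) << std::endl; // Int
}

The casts themselves reuse the existing castITensor helper, so promotion only changes operand dtypes; the element-wise layer itself is left untouched.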