diff --git a/core/conversion/converters/impl/constant_pad.cpp b/core/conversion/converters/impl/constant_pad.cpp index 679a23f875..6d3f1ab609 100644 --- a/core/conversion/converters/impl/constant_pad.cpp +++ b/core/conversion/converters/impl/constant_pad.cpp @@ -21,7 +21,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns auto padding = args[1].unwrapToIntList().vec(); int64_t padSize = padding.size(); auto value = args[2].unwrapToScalar().to(); - + at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType())); + auto valueTensor = tensor_to_const(ctx, value_tensor); TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize); int64_t l_pad = padSize / 2; @@ -55,10 +56,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE); auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0); fill_layer->setInput(0, *shape_gather_out); - at::Tensor value_tensor = torch::tensor(value, torch::kFloat32); - auto valueTensor = tensor_to_const(ctx, value_tensor); fill_layer->setInput(1, *valueTensor); - at::Tensor delta_tensor = torch::zeros(inRank); + at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType())); auto deltaTensor = tensor_to_const(ctx, delta_tensor); fill_layer->setInput(2, *deltaTensor); auto padTensor = fill_layer->getOutput(0); @@ -69,10 +68,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns } else { inDims.d[axis] = padding[padding_index]; auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE); - at::Tensor value_tensor = torch::tensor(value, torch::kFloat32); - auto valueTensor = tensor_to_const(ctx, value_tensor); fill_layer->setInput(1, *valueTensor); - at::Tensor delta_tensor = torch::zeros(inRank); + at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType())); auto deltaTensor = tensor_to_const(ctx, delta_tensor); fill_layer->setInput(2, *deltaTensor); auto padTensor = fill_layer->getOutput(0); @@ -112,10 +109,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE); auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0); fill_layer->setInput(0, *shape_gather_out); - at::Tensor value_tensor = torch::tensor(value, torch::kFloat32); - auto valueTensor = tensor_to_const(ctx, value_tensor); fill_layer->setInput(1, *valueTensor); - at::Tensor delta_tensor = torch::zeros(inRank); + at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType())); auto deltaTensor = tensor_to_const(ctx, delta_tensor); fill_layer->setInput(2, *deltaTensor); auto padTensor = fill_layer->getOutput(0); @@ -126,10 +121,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns } else { inDims.d[axis] = padding[padding_index + 1]; auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE); - at::Tensor value_tensor = torch::tensor(value, torch::kFloat32); - auto valueTensor = tensor_to_const(ctx, value_tensor); fill_layer->setInput(1, *valueTensor); - at::Tensor delta_tensor = torch::zeros(inRank); + at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType())); auto deltaTensor = tensor_to_const(ctx, delta_tensor); fill_layer->setInput(2, *deltaTensor); auto padTensor = fill_layer->getOutput(0); diff --git a/tests/core/conversion/converters/test_constant_pad.cpp b/tests/core/conversion/converters/test_constant_pad.cpp index 9b37be4352..c5f0bd8a31 100644 --- a/tests/core/conversion/converters/test_constant_pad.cpp +++ b/tests/core/conversion/converters/test_constant_pad.cpp @@ -28,6 +28,29 @@ TEST(Converters, ATenConstantPad1dTensorConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Converters, ATenConstantPad1dIntTensorConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3]]() + %2 : Scalar = prim::Constant[value=2]() + %3 : Tensor = aten::constant_pad_nd(%0, %1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {1, 3, 4}, {at::kCUDA}).toType(at::kInt); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + TEST(Converters, ATenConstantPad1dRightZeroTensorConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):