@@ -21,7 +21,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
2121 auto padding = args[1 ].unwrapToIntList ().vec ();
2222 int64_t padSize = padding.size ();
2323 auto value = args[2 ].unwrapToScalar ().to <float >();
24-
24+ at::Tensor value_tensor = torch::tensor (value, util::TRTDataTypeToScalarType (in->getType ()));
25+ auto valueTensor = tensor_to_const (ctx, value_tensor);
2526 TORCHTRT_CHECK (padSize % 2 == 0 , " Length of pad must be even but instead it equals " << padSize);
2627
2728 int64_t l_pad = padSize / 2 ;
@@ -55,10 +56,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
5556 auto fill_layer = ctx->net ->addFill (nvinfer1::Dims{1 , {1 }}, nvinfer1::FillOperation::kLINSPACE );
5657 auto shape_gather_out = ctx->net ->addShape (*left_gather_out)->getOutput (0 );
5758 fill_layer->setInput (0 , *shape_gather_out);
58- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
59- auto valueTensor = tensor_to_const (ctx, value_tensor);
6059 fill_layer->setInput (1 , *valueTensor);
61- at::Tensor delta_tensor = torch::zeros (inRank);
60+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
6261 auto deltaTensor = tensor_to_const (ctx, delta_tensor);
6362 fill_layer->setInput (2 , *deltaTensor);
6463 auto padTensor = fill_layer->getOutput (0 );
@@ -69,10 +68,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
6968 } else {
7069 inDims.d [axis] = padding[padding_index];
7170 auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
72- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
73- auto valueTensor = tensor_to_const (ctx, value_tensor);
7471 fill_layer->setInput (1 , *valueTensor);
75- at::Tensor delta_tensor = torch::zeros (inRank);
72+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
7673 auto deltaTensor = tensor_to_const (ctx, delta_tensor);
7774 fill_layer->setInput (2 , *deltaTensor);
7875 auto padTensor = fill_layer->getOutput (0 );
@@ -112,10 +109,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
112109 auto fill_layer = ctx->net ->addFill (nvinfer1::Dims{1 , {1 }}, nvinfer1::FillOperation::kLINSPACE );
113110 auto shape_gather_out = ctx->net ->addShape (*right_gather_out)->getOutput (0 );
114111 fill_layer->setInput (0 , *shape_gather_out);
115- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
116- auto valueTensor = tensor_to_const (ctx, value_tensor);
117112 fill_layer->setInput (1 , *valueTensor);
118- at::Tensor delta_tensor = torch::zeros (inRank);
113+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
119114 auto deltaTensor = tensor_to_const (ctx, delta_tensor);
120115 fill_layer->setInput (2 , *deltaTensor);
121116 auto padTensor = fill_layer->getOutput (0 );
@@ -126,10 +121,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
126121 } else {
127122 inDims.d [axis] = padding[padding_index + 1 ];
128123 auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
129- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
130- auto valueTensor = tensor_to_const (ctx, value_tensor);
131124 fill_layer->setInput (1 , *valueTensor);
132- at::Tensor delta_tensor = torch::zeros (inRank);
125+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
133126 auto deltaTensor = tensor_to_const (ctx, delta_tensor);
134127 fill_layer->setInput (2 , *deltaTensor);
135128 auto padTensor = fill_layer->getOutput (0 );
0 commit comments