@@ -166,11 +166,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
166166 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
167167 // Should implement self - alpha * other
168168 auto self = args[0 ].ITensorOrFreeze (ctx);
169- auto scalar = args[2 ].unwrapToScalar ().to <float >();
170169 auto other = args[1 ].ITensorOrFreeze (ctx);
170+ auto scalar = args[2 ].unwrapToScalar ();
171171
172- if (1 != scalar) {
173- auto alphaTensor = tensor_to_const (ctx, torch::tensor ({ scalar}) );
172+ if (1 != scalar. to < float >() ) {
173+ auto alphaTensor = scalar_to_tensor (ctx, scalar);
174174 auto scaleLayer = add_elementwise (
175175 ctx,
176176 nvinfer1::ElementWiseOperation::kPROD ,
@@ -214,11 +214,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
214214 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
215215 // Should implement self - alpha * other
216216 auto self = args[0 ].ITensorOrFreeze (ctx);
217- auto scalar = args[2 ].unwrapToScalar ().to <float >();
218217 auto other = args[1 ].ITensorOrFreeze (ctx);
218+ auto scalar = args[2 ].unwrapToScalar ();
219219
220- if (1 != scalar) {
221- auto alphaTensor = tensor_to_const (ctx, torch::tensor ({ scalar}) );
220+ if (1 != scalar. to < float >() ) {
221+ auto alphaTensor = scalar_to_tensor (ctx, scalar);
222222 auto scaleLayer = add_elementwise (
223223 ctx,
224224 nvinfer1::ElementWiseOperation::kPROD ,
@@ -351,8 +351,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
351351 {" aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
352352 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
353353 auto self = args[0 ].ITensorOrFreeze (ctx);
354- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
355- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
354+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
356355 auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
357356 TORCHTRT_CHECK (div, " Unable to create div layer from node: " << *n);
358357
@@ -381,8 +380,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
381380 {" aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)" ,
382381 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
383382 auto self = args[0 ].ITensorOrFreeze (ctx);
384- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
385- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
383+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
386384 auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
387385 TORCHTRT_CHECK (div, " Unable to create div layer from node: " << *n);
388386
@@ -481,18 +479,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
481479 {" aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
482480 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
483481 auto self = args[0 ].ITensorOrFreeze (ctx);
484- auto scalar = args[1 ].unwrapToScalar ();
485- nvinfer1::ITensor* scalar_tensor;
486- if (self->getType () == nvinfer1::DataType::kFLOAT || self->getType () == nvinfer1::DataType::kHALF ) {
487- scalar_tensor = tensor_to_const (ctx, torch::tensor ({scalar.to <float >()}));
488- } else {
489- scalar_tensor = tensor_to_const (ctx, torch::tensor ({scalar.to <int >()}));
490- }
482+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
491483 auto equal = add_elementwise (
492484 ctx,
493485 nvinfer1::ElementWiseOperation::kEQUAL ,
494486 self,
495- scalar_tensor ,
487+ other ,
496488 util::node_info (n) + std::string (" is_equal" ));
497489 TORCHTRT_CHECK (equal, " Unable to create elementwise equal layer from node: " << *n);
498490 // XOR with ones negates and produces not_equal result
@@ -534,8 +526,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
534526 {" aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)" ,
535527 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
536528 auto self = args[0 ].ITensorOrFreeze (ctx);
537- auto exponentScalar = args[1 ].unwrapToScalar ().to <float >();
538- auto exponent = tensor_to_const (ctx, torch::tensor ({exponentScalar}));
529+ auto exponent = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
539530 auto pow =
540531 add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPOW , self, exponent, util::node_info (n));
541532 TORCHTRT_CHECK (pow, " Unable to create Power layer from node: " << *n);
@@ -681,9 +672,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
681672 {" aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
682673 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
683674 auto self = args[0 ].ITensorOrFreeze (ctx);
684- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
685- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
675+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
686676 if (self->getType () == nvinfer1::DataType::kBOOL ) {
677+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
687678 if (otherScalar == 0 || otherScalar == 1 ) {
688679 LOG_DEBUG (" Since input tensor is type bool, casting input tensor and scalar to int32" );
689680 other = castITensor (ctx, other, nvinfer1::DataType::kINT32 );
0 commit comments