@@ -166,11 +166,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Should implement self - alpha * other
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
               auto other = args[1].ITensorOrFreeze(ctx);
+              auto scalar = args[2].unwrapToScalar();

-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -214,11 +214,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Should implement self - alpha * other
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
               auto other = args[1].ITensorOrFreeze(ctx);
+              auto scalar = args[2].unwrapToScalar();

-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -351,8 +351,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());

               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -381,8 +380,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());

               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -481,18 +479,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto scalar = args[1].unwrapToScalar();
-              nvinfer1::ITensor* scalar_tensor;
-              if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
-                scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
-              } else {
-                scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
-              }
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());

               auto equal = add_elementwise(
                   ctx,
                   nvinfer1::ElementWiseOperation::kEQUAL,
                   self,
-                  scalar_tensor,
+                  other,
                   util::node_info(n) + std::string("is_equal"));
               TORCHTRT_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
               // XOR with ones negates and produces not_equal result
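Note: the new scalar_to_tensor helper centralizes the scalar-to-constant conversion that the removed ne.Scalar code did inline. Below is a minimal sketch of what such a helper could look like; the actual implementation in Torch-TensorRT's converter utilities may differ, and dispatching on the Scalar's own type (rather than on the destination tensor's type, as the removed code did) is an assumption for illustration.

// Sketch only: a plausible scalar_to_tensor, not the repository's exact code.
nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    // Integral scalars become integer constants instead of being forced to float.
    return tensor_to_const(ctx, torch::tensor({s.to<int>()}));
  }
  // All other scalars become single-element float constants, matching the old behavior.
  return tensor_to_const(ctx, torch::tensor({s.to<float>()}));
}

Keeping the type decision in one helper means every Scalar-taking converter in this file handles scalar constants consistently, instead of each converter hard-coding a conversion to float.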
@@ -534,8 +526,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto exponentScalar = args[1].unwrapToScalar().to<float>();
-              auto exponent = tensor_to_const(ctx, torch::tensor({exponentScalar}));
+              auto exponent = scalar_to_tensor(ctx, args[1].unwrapToScalar());

               auto pow =
                   add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
               TORCHTRT_CHECK(pow, "Unable to create Power layer from node: " << *n);
@@ -681,9 +672,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               if (self->getType() == nvinfer1::DataType::kBOOL) {
+                auto otherScalar = args[1].unwrapToScalar().to<float>();
                 if (otherScalar == 0 || otherScalar == 1) {
                   LOG_DEBUG("Since input tensor is type bool, casting input tensor and scalar to int32");
                   other = castITensor(ctx, other, nvinfer1::DataType::kINT32);