@@ -25,6 +25,27 @@ nvinfer1::ITensor* clamp_util(
25
25
return clamp_layer_out;
26
26
}
27
27
28
+ nvinfer1::ITensor* scalar_to_tensor (ConversionCtx* ctx, at::Scalar s) {
29
+ nvinfer1::ITensor* out;
30
+ if (s.isIntegral (false )) {
31
+ auto s_int = s.to <int64_t >();
32
+ auto s_t = torch::tensor ({s_int}).to (at::kInt );
33
+ out = tensor_to_const (ctx, s_t );
34
+ } else if (s.isBoolean ()) {
35
+ auto s_bool = s.to <bool >();
36
+ auto s_t = torch::tensor ({s_bool}).to (at::kBool );
37
+ out = tensor_to_const (ctx, s_t );
38
+ } else if (s.isFloatingPoint ()) {
39
+ auto other_float = s.to <float >();
40
+ auto s_t = torch::tensor ({other_float});
41
+ out = tensor_to_const (ctx, s_t );
42
+ } else {
43
+ out = nullptr ;
44
+ TRTORCH_THROW_ERROR (" Unsupported data type for scalar. Found: (" << s.type () << " )" );
45
+ }
46
+ return out;
47
+ }
48
+
28
49
auto element_wise_registrations TRTORCH_UNUSED =
29
50
RegisterNodeConversionPatterns ()
30
51
.pattern({" aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
@@ -557,8 +578,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
557
578
.pattern({" aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
558
579
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
559
580
auto self = args[0 ].ITensorOrFreeze (ctx);
560
- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
561
- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
581
+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
582
+ if (self->getType () != other->getType ()) {
583
+ other = castITensor (ctx, other, self->getType ());
584
+ }
562
585
auto gt =
563
586
add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n));
564
587
TRTORCH_CHECK (gt, " Unable to create greater layer from node: " << *n);
@@ -584,8 +607,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
584
607
.pattern({" aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
585
608
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
586
609
auto self = args[0 ].ITensorOrFreeze (ctx);
587
- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
588
- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
610
+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
611
+ if (self->getType () != other->getType ()) {
612
+ other = castITensor (ctx, other, self->getType ());
613
+ }
589
614
auto lt =
590
615
add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n));
591
616
TRTORCH_CHECK (lt, " Unable to create less layer from node: " << *n);
@@ -613,6 +638,18 @@ auto element_wise_registrations TRTORCH_UNUSED =
613
638
auto self = args[0 ].ITensorOrFreeze (ctx);
614
639
auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
615
640
auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
641
+ if (self->getType () == nvinfer1::DataType::kBOOL ) {
642
+ if (otherScalar == 0 || otherScalar == 1 ) {
643
+ LOG_DEBUG (" Since input tensor is type bool, casting input tensor and scalar to int32" );
644
+ other = castITensor (ctx, other, nvinfer1::DataType::kINT32 );
645
+ self = castITensor (ctx, self, nvinfer1::DataType::kINT32 );
646
+ } else {
647
+ LOG_WARNING (" Input Tensor has type bool, but scalar is not 0 or 1. Found: " << otherScalar);
648
+ }
649
+ }
650
+ if (self->getType () != other->getType ()) {
651
+ other = castITensor (ctx, other, self->getType ());
652
+ }
616
653
auto eq =
617
654
add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n));
618
655
TRTORCH_CHECK (eq, " Unable to create equal layer from node: " << *n);
@@ -648,8 +685,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
648
685
.pattern({" aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
649
686
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
650
687
auto self = args[0 ].ITensorOrFreeze (ctx);
651
- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
652
- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
688
+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
689
+ if (self->getType () != other->getType ()) {
690
+ other = castITensor (ctx, other, self->getType ());
691
+ }
653
692
654
693
auto greater = add_elementwise (
655
694
ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n) + " _greater" );
@@ -695,8 +734,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
695
734
.pattern({" aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
696
735
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
697
736
auto self = args[0 ].ITensorOrFreeze (ctx);
698
- auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
699
- auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
737
+ auto other = scalar_to_tensor (ctx, args[1 ].unwrapToScalar ());
738
+ if (self->getType () != other->getType ()) {
739
+ other = castITensor (ctx, other, self->getType ());
740
+ }
700
741
701
742
auto less = add_elementwise (
702
743
ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n) + " _less" );
0 commit comments