Commit cde5e55

Merge pull request #1448 from Njuapp/scalar_to_tensor
scalar_to_tensor avoid scalar.to<float>()
2 parents a029c2a + 28eb274 commit cde5e55
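
The commit replaces the per-converter pattern of unwrapping a Scalar, forcing it to float, and wrapping it with tensor_to_const(ctx, torch::tensor({...})) with a single call to scalar_to_tensor(ctx, ...), which avoids the unconditional scalar.to&lt;float&gt;() narrowing named in the commit title; a hedged sketch of what such a helper can look like is included after the diff below.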

File tree

1 file changed: +13, -22 lines

core/conversion/converters/impl/element_wise.cpp

Lines changed: 13 additions & 22 deletions
@@ -166,11 +166,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          // Should implement self - alpha * other
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto scalar = args[2].unwrapToScalar().to<float>();
          auto other = args[1].ITensorOrFreeze(ctx);
+         auto scalar = args[2].unwrapToScalar();

-         if (1 != scalar) {
-           auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+         if (1 != scalar.to<float>()) {
+           auto alphaTensor = scalar_to_tensor(ctx, scalar);
            auto scaleLayer = add_elementwise(
                ctx,
                nvinfer1::ElementWiseOperation::kPROD,
@@ -214,11 +214,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          // Should implement self - alpha * other
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto scalar = args[2].unwrapToScalar().to<float>();
          auto other = args[1].ITensorOrFreeze(ctx);
+         auto scalar = args[2].unwrapToScalar();

-         if (1 != scalar) {
-           auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+         if (1 != scalar.to<float>()) {
+           auto alphaTensor = scalar_to_tensor(ctx, scalar);
            auto scaleLayer = add_elementwise(
                ctx,
                nvinfer1::ElementWiseOperation::kPROD,
@@ -351,8 +351,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto otherScalar = args[1].unwrapToScalar().to<float>();
-         auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
          TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

@@ -381,8 +380,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto otherScalar = args[1].unwrapToScalar().to<float>();
-         auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
          TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

@@ -481,18 +479,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto scalar = args[1].unwrapToScalar();
-         nvinfer1::ITensor* scalar_tensor;
-         if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
-           scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
-         } else {
-           scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
-         }
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto equal = add_elementwise(
              ctx,
              nvinfer1::ElementWiseOperation::kEQUAL,
              self,
-             scalar_tensor,
+             other,
              util::node_info(n) + std::string("is_equal"));
          TORCHTRT_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
          // XOR with ones negates and produces not_equal result
@@ -534,8 +526,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto exponentScalar = args[1].unwrapToScalar().to<float>();
-         auto exponent = tensor_to_const(ctx, torch::tensor({exponentScalar}));
+         auto exponent = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto pow =
              add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
          TORCHTRT_CHECK(pow, "Unable to create Power layer from node: " << *n);
@@ -681,9 +672,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto otherScalar = args[1].unwrapToScalar().to<float>();
-         auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          if (self->getType() == nvinfer1::DataType::kBOOL) {
+           auto otherScalar = args[1].unwrapToScalar().to<float>();
            if (otherScalar == 0 || otherScalar == 1) {
              LOG_DEBUG("Since input tensor is type bool, casting input tensor and scalar to int32");
              other = castITensor(ctx, other, nvinfer1::DataType::kINT32);
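
For orientation only: below is a minimal sketch of what a scalar_to_tensor-style helper could look like, assuming it dispatches on the at::Scalar's own type rather than on the input tensor's type as the removed aten::ne.Scalar code did. The name scalar_to_tensor_sketch and the exact branches are illustrative assumptions, not the repository's actual implementation; ConversionCtx and tensor_to_const are the helpers already used throughout the diff above.

// Illustrative sketch: builds a single-element TensorRT constant whose dtype
// follows the at::Scalar, instead of forcing every scalar through .to<float>().
// Assumes the tensor_to_const(ConversionCtx*, at::Tensor) helper from this file.
nvinfer1::ITensor* scalar_to_tensor_sketch(ConversionCtx* ctx, const at::Scalar& s) {
  if (s.isFloatingPoint()) {
    // Floating-point scalars become float constants.
    return tensor_to_const(ctx, torch::tensor({s.to<float>()}));
  } else if (s.isBoolean()) {
    // Boolean scalars keep a bool dtype instead of being cast to float.
    return tensor_to_const(ctx, torch::tensor({s.to<bool>()}));
  } else {
    // Integral scalars become int constants rather than float ones.
    return tensor_to_const(ctx, torch::tensor({s.to<int>()}));
  }
}

With such a helper, a converter builds its constant operand as auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar()); which is exactly the call pattern introduced by this diff.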
