1
1
#include " core/conversion/converters/converters.h"
2
2
#include " core/util/prelude.h"
3
3
4
+ #include " torch/torch.h"
5
+
4
6
namespace torch_tensorrt {
5
7
namespace core {
6
8
namespace conversion {
7
9
namespace converters {
8
10
namespace impl {
9
11
namespace {
10
12
13
+
14
+ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
15
+ {" aten::abs (Tensor self) -> Tensor" ,
16
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
17
+ auto in = args[0 ].ITensor ();
18
+ bool unary_supported_input = in->getType () == nvinfer1::DataType::kFLOAT
19
+ || in->getType () == nvinfer1::DataType::kHALF
20
+ || in->getType () == nvinfer1::DataType::kINT8 ;
21
+ if (unary_supported_input){
22
+ auto unary_layer = ctx->net ->addUnary (*in, nvinfer1::UnaryOperation::kABS );
23
+ TORCHTRT_CHECK (unary_layer, " Unable to create abs layer from node: " << *n);
24
+ unary_layer->setName (util::node_info (n).c_str ());
25
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], unary_layer->getOutput (0 ));
26
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
27
+ return true ;
28
+ }
29
+ else {
30
+ // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
31
+ at::Tensor neg_one = torch::full ({1 }, -1 ).to (util::TRTDataTypeToScalarType (in->getType ()));
32
+ auto neg_one_const = tensor_to_const (ctx, neg_one);
33
+ auto neg_layer = add_elementwise (
34
+ ctx,
35
+ nvinfer1::ElementWiseOperation::kPROD ,
36
+ in,
37
+ neg_one_const,
38
+ util::node_info (n) + std::string (" _Negation" ));
39
+ TORCHTRT_CHECK (neg_layer, " Unable to create prod layer from node: " << *n);
40
+ auto max_layer = add_elementwise (
41
+ ctx,
42
+ nvinfer1::ElementWiseOperation::kMAX ,
43
+ in,
44
+ neg_layer->getOutput (0 ),
45
+ util::node_info (n) + std::string (" _Max" ));
46
+ TORCHTRT_CHECK (max_layer, " Unable to create max layer from node: " << *n);
47
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], max_layer->getOutput (0 ));
48
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
49
+ return true ;
50
+ }
51
+ }});
52
+
11
53
#define convert (unary, trt_type ) \
12
54
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
13
55
{" aten::" #unary " (Tensor self) -> Tensor" , \
@@ -32,7 +74,6 @@ convert(asin, kASIN);
32
74
convert (sinh, kSINH );
33
75
convert (tan, kTAN );
34
76
convert (atan, kATAN );
35
- convert (abs, kABS );
36
77
convert (floor, kFLOOR );
37
78
convert (reciprocal, kRECIP );
38
79
convert (log, kLOG );
0 commit comments