@@ -71,35 +71,34 @@ auto select_registrations TRTORCH_UNUSED =
7171 RegisterNodeConversionPatterns ()
7272 .pattern({" aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))" ,
7373 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
74- auto in = args[0 ].ITensor ( );
74+ auto in = args[0 ].ITensorOrFreeze (ctx );
7575 auto maxDim = static_cast <int64_t >(in->getDimensions ().nbDims );
7676 auto axis = args[1 ].unwrapToInt ();
7777 axis = axis < 0 ? axis + maxDim : axis;
7878 auto ind = (int32_t )args[2 ].unwrapToInt ();
7979
8080 // index to access needs to be an at::Tensor
8181 at::Tensor indices = torch::tensor ({ind}).to (torch::kI32 );
82- auto weights = Weights (ctx, indices);
83-
84- // IConstantLayer to convert indices from Weights to ITensor
85- auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
86- TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
87- auto const_out = const_layer->getOutput (0 );
82+ auto const_out = tensor_to_const (ctx, indices);
8883
8984 // IGatherLayer takes in input tensor, the indices, and the axis
9085 // of input tensor to take indices from
9186 auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
9287 TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
93- auto gather_out = gather_layer->getOutput (0 );
88+ auto out = gather_layer->getOutput (0 );
9489
95- // IShuffleLayer removes redundant dimensions
96- auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
97- TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
98- shuffle_layer->setReshapeDimensions (util::squeezeDims (gather_out->getDimensions (), axis));
99- shuffle_layer->setName (util::node_info (n).c_str ());
100- auto shuffle_out = shuffle_layer->getOutput (0 );
90+ LOG_DEBUG (" Gather tensor shape: " << out->getDimensions ());
10191
102- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
92+ if (out->getDimensions ().nbDims != 1 ) {
93+ // IShuffleLayer removes redundant dimensions
94+ auto shuffle_layer = ctx->net ->addShuffle (*out);
95+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
96+ shuffle_layer->setReshapeDimensions (util::squeezeDims (out->getDimensions (), axis));
97+ shuffle_layer->setName (util::node_info (n).c_str ());
98+ out = shuffle_layer->getOutput (0 );
99+ }
100+
101+ out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out);
103102
104103 LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
105104
@@ -253,15 +252,14 @@ auto select_registrations TRTORCH_UNUSED =
253252 " aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)" ,
254253 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
255254 auto self = args[0 ].ITensorOrFreeze (ctx);
256- LOG_DEBUG (args[1 ].unwrapToTensor ());
257255 auto mask = castITensor (ctx, args[1 ].ITensorOrFreeze (ctx), nvinfer1::DataType::kBOOL );
256+ mask = addPadding (ctx, n, mask, self->getDimensions ().nbDims , false , true );
258257 auto val = args[2 ].unwrapToScalar ().to <float >();
259- LOG_DEBUG (torch::full (util::toVec (self->getDimensions ()), val));
260258 auto val_t = tensor_to_const (ctx, torch::full (util::toVec (self->getDimensions ()), val));
261259
262260 TRTORCH_CHECK (util::broadcastable (self->getDimensions (), mask->getDimensions (), /* multidirectional=*/ false ), " Self and mask tensors are not broadcastable" );
263261
264- auto new_layer = ctx->net ->addSelect (*mask, *self , *val_t );
262+ auto new_layer = ctx->net ->addSelect (*mask, *val_t , *self );
265263 TRTORCH_CHECK (new_layer, " Unable to create layer for aten::masked_fill" );
266264
267265 new_layer->setName (util::node_info (n).c_str ());
0 commit comments