@@ -198,7 +198,7 @@ auto expand_registrations TRTORCH_UNUSED =
198198 RegisterNodeConversionPatterns ()
199199 .pattern({" aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))" ,
200200 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
201- auto in = args[0 ].ITensor ( );
201+ auto in = args[0 ].ITensorOrFreeze (ctx );
202202 auto input_dims = in->getDimensions ();
203203 auto expanded_size = args[1 ].unwrapToIntList ();
204204 auto expandedDims = util::toDims (expanded_size);
@@ -213,9 +213,9 @@ auto expand_registrations TRTORCH_UNUSED =
213213 }})
214214 .pattern({" aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))" ,
215215 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
216- auto in = args[0 ].ITensor ( );
216+ auto in = args[0 ].ITensorOrFreeze (ctx );
217217 auto input_dims = in->getDimensions ();
218- auto targetTensor = args[1 ].ITensor ( );
218+ auto targetTensor = args[1 ].ITensorOrFreeze (ctx );
219219 auto targetDims = targetTensor->getDimensions ();
220220 LOG_DEBUG (" (expand_as layer) Expand input from " << input_dims << " to " << targetDims);
221221 if (ctx->input_is_dynamic ) {
@@ -227,7 +227,7 @@ auto expand_registrations TRTORCH_UNUSED =
227227 }})
228228 .pattern({" aten::repeat(Tensor self, int[] repeats) -> (Tensor)" ,
229229 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
230- auto in = args[0 ].ITensor ( );
230+ auto in = args[0 ].ITensorOrFreeze (ctx );
231231 auto input_dims = in->getDimensions ();
232232 auto repeats = args[1 ].unwrapToIntList ().vec ();
233233 int repeats_rank = repeats.size ();
0 commit comments