Commit f350699

feat (//core/conversion): Add converter for torch.repeat_interleave (#1313)
* added repeat_interleave int repeats converter
* fixed compile time errors
* added repeat_interleave tests, moved converter to expand file
* repeat_interleave passing tests for static input
* implementation and tests for dynamic input repeat_interleave
* dynamic shape checks
* reformatting
1 parent f921c35 commit f350699
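
For reference, `aten::repeat_interleave.self_int` takes a single integer repeat count plus an optional dimension. A minimal libtorch sketch (illustration only, not part of this commit) of the behavior the new converter has to reproduce:

// Illustration of aten::repeat_interleave.self_int semantics (libtorch).
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::arange(1, 5).reshape({2, 2});  // {{1, 2}, {3, 4}}

  // With a dim: each slice along dim 1 is repeated 3 times in place,
  // giving {{1, 1, 1, 2, 2, 2}, {3, 3, 3, 4, 4, 4}}.
  std::cout << x.repeat_interleave(3, /*dim=*/1) << "\n";

  // Without a dim: the tensor is flattened first and every element is
  // repeated, giving {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}.
  std::cout << x.repeat_interleave(3) << "\n";
}

When `dim` is omitted the op flattens its input first, which is why the converter below has a dedicated flatten branch.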

File tree: 2 files changed, +334 -0 lines changed

* core/conversion/converters/impl/expand.cpp
* tests/core/conversion/converters/test_expand.cpp

core/conversion/converters/impl/expand.cpp

Lines changed: 110 additions & 0 deletions
@@ -282,6 +282,116 @@ auto expand_registrations TORCHTRT_UNUSED =
              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

              LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+             return true;
+           }})
+       .pattern(
+           {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto self = args[0].ITensorOrFreeze(ctx);
+              auto repeats = args[1].unwrapToScalar().to<int>();
+
+              auto input_shape = self->getDimensions();
+
+              int dim;
+              if (args[2].IValue()->isNone()) {
+                dim = 0;
+
+                // Flatten self tensor
+                int size;
+                if (ctx->input_is_dynamic) {
+                  // Set size to -1 if input is dynamic
+                  size = -1;
+                } else {
+                  size = 1;
+                  for (int i = 0; i < input_shape.nbDims; i++) {
+                    size *= input_shape.d[i];
+                  }
+                }
+                auto flatten = ctx->net->addShuffle(*self);
+                TORCHTRT_CHECK(flatten, "Unable to create shuffle layer from node: " << *n);
+                flatten->setReshapeDimensions(util::toDims(std::vector<int64_t>({size})));
+                self = flatten->getOutput(0);
+                input_shape = self->getDimensions();
+              } else {
+                dim = args[2].unwrapToScalar().to<int>();
+              }
+
+              if (ctx->input_is_dynamic) {
+                int dynamic_dims = 0;
+                for (int idx = 0; idx < input_shape.nbDims; idx++) {
+                  if (input_shape.d[idx] == -1) {
+                    dynamic_dims++;
+                  }
+                }
+
+                if (dynamic_dims > 1) {
+                  TORCHTRT_THROW_ERROR(
+                      "Repeat_interleave is currently not supported when target shape contains more than one dynamic dimension");
+                }
+              }
+
+              // Insert singleton dimension after desired repeat dimension
+              std::vector<int64_t> repeat_shape_vec;
+              for (int j = 0; j < input_shape.nbDims; j++) {
+                repeat_shape_vec.push_back(input_shape.d[j]);
+                if (j == dim) {
+                  repeat_shape_vec.push_back(1);
+                }
+              }
+              auto expand = ctx->net->addShuffle(*self);
+              TORCHTRT_CHECK(expand, "Unable to create shuffle layer from node: " << *n);
+              auto repeat_shape_dims = util::toDims(repeat_shape_vec);
+              expand->setReshapeDimensions(repeat_shape_dims);
+
+              // Expand on newly created singleton dimension
+              repeat_shape_dims.d[dim + 1] = repeats;
+              std::vector<int64_t> start_vec(repeat_shape_dims.nbDims, 0);
+              auto start_dims = util::toDims(start_vec);
+
+              std::vector<int64_t> strides_vec(repeat_shape_dims.nbDims, 1);
+              strides_vec[dim + 1] = 0;
+              auto strides_dims = util::toDims(strides_vec);
+
+              auto slice = ctx->net->addSlice(*expand->getOutput(0), start_dims, repeat_shape_dims, strides_dims);
+
+              if (ctx->input_is_dynamic) {
+                auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));
+
+                auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
+                std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
+                repeat_const_vec[dim + 1] = repeats;
+                auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));
+                auto repeat_shape_tensor =
+                    ctx->net
+                        ->addElementWise(*expand_output_shape, *repeat_const, nvinfer1::ElementWiseOperation::kPROD)
+                        ->getOutput(0);
+
+                auto strides_tensor = tensor_to_const(ctx, torch::tensor(strides_vec, torch::kInt32));
+                slice->setInput(1, *start_tensor);
+                slice->setInput(2, *repeat_shape_tensor);
+                slice->setInput(3, *strides_tensor);
+              }
+
+              // Collapse repeated dimension back into desired dimension
+              std::vector<int64_t> collapse_shape_vec;
+              for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
+                if (k == dim) {
+                  int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[++k];
+                  // Set dim size to -1 if repeat is being done on dynamic dim
+                  collapse_dim = std::max(collapse_dim, (int64_t)-1);
+                  collapse_shape_vec.push_back(collapse_dim);
+                } else {
+                  collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
+                }
+              }
+              auto collapse = ctx->net->addShuffle(*slice->getOutput(0));
+              TORCHTRT_CHECK(collapse, "Unable to create shuffle layer from node: " << *n);
+              collapse->setReshapeDimensions(util::toDims(collapse_shape_vec));
+
+              collapse->setName(util::node_info(n).c_str());
+              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
+              LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
              return true;
            }});
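
The converter above builds the op from three TensorRT layers: an IShuffleLayer that inserts a singleton dimension right after the repeat dimension, a zero-stride ISliceLayer that broadcasts that dimension to `repeats` (with start/shape/stride supplied as shape tensors via `setInput` when the input is dynamic), and a second IShuffleLayer that folds the repeated dimension back in. A minimal libtorch sketch of the same decomposition, assuming static shapes and an explicit `dim` (an illustration of the strategy, not the TensorRT code itself):

// Sketch: repeat_interleave expressed as unsqueeze -> broadcast -> collapse.
#include <torch/torch.h>

torch::Tensor repeat_interleave_by_expand(const torch::Tensor& self, int64_t repeats, int64_t dim) {
  auto shape = self.sizes().vec();

  // 1. Insert a singleton dimension right after `dim`
  //    (the converter's first shuffle layer).
  std::vector<int64_t> expanded_shape(shape);
  expanded_shape.insert(expanded_shape.begin() + dim + 1, 1);
  auto unsqueezed = self.reshape(expanded_shape);

  // 2. Broadcast the singleton dimension to `repeats`
  //    (the converter's zero-stride slice layer).
  expanded_shape[dim + 1] = repeats;
  auto broadcast = unsqueezed.expand(expanded_shape);

  // 3. Collapse the repeated dimension back into `dim`
  //    (the converter's second shuffle layer).
  std::vector<int64_t> collapsed_shape(shape);
  collapsed_shape[dim] *= repeats;
  return broadcast.reshape(collapsed_shape);
}

int main() {
  auto x = torch::randint(1, 10, {2, 3, 2});
  TORCH_CHECK(torch::equal(repeat_interleave_by_expand(x, 3, /*dim=*/1),
                           x.repeat_interleave(3, /*dim=*/1)));
}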

tests/core/conversion/converters/test_expand.cpp

Lines changed: 224 additions & 0 deletions
@@ -445,3 +445,227 @@ TEST(Converters, ATenRepeatExtraDimsConvertsCorrectlyWithDynamicInput) {

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
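
The ScalarNoDim variants above exercise the `dim=None` branch, where the converter first flattens the input with a shuffle layer (reshaping to the total element count, or -1 for dynamic input) and then interleaves along dimension 0. A short libtorch sketch of that equivalence (illustrative only, not taken from the tests):

// dim=None: flatten, then repeat each element along dim 0.
#include <torch/torch.h>

int main() {
  auto x = torch::randint(1, 10, {2, 3, 2});

  // Flatten to a 1-D tensor, then interleave along dim 0, mirroring the
  // converter's flatten branch followed by the expand/collapse shuffles.
  auto manual = x.reshape({x.numel()}).repeat_interleave(3, /*dim=*/0);

  // aten::repeat_interleave with only a repeat count should give the same result.
  TORCH_CHECK(torch::equal(manual, x.repeat_interleave(3)));
}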
