@@ -445,3 +445,227 @@ TEST(Converters, ATenRepeatExtraDimsConvertsCorrectlyWithDynamicInput) {
445
445
446
446
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
447
447
}
448
+
449
+ TEST (Converters, ATenRepeatInterleaveScalarDimConvertsCorrectly) {
450
+ const auto graph = R"IR(
451
+ graph(%x.1 : Tensor):
452
+ %2 : int = prim::Constant[value=3]()
453
+ %3 : int = prim::Constant[value=1]()
454
+ %4 : None = prim::Constant()
455
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
456
+ return (%5))IR" ;
457
+
458
+ auto g = std::make_shared<torch::jit::Graph>();
459
+
460
+ torch::jit::parseIR (graph, g.get ());
461
+
462
+ auto in = at::randint (1 , 10 , {1 , 3 }, {at::kCUDA });
463
+
464
+ auto jit_in = at::clone (in);
465
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
466
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
467
+
468
+ auto trt_in = at::clone (in);
469
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
470
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
471
+
472
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
473
+
474
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
475
+ }
476
+
477
+ TEST (Converters, ATenRepeatInterleaveScalarDimConvertsCorrectlyWithDynamicInput) {
478
+ const auto graph = R"IR(
479
+ graph(%x.1 : Tensor):
480
+ %2 : int = prim::Constant[value=3]()
481
+ %3 : int = prim::Constant[value=1]()
482
+ %4 : None = prim::Constant()
483
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
484
+ return (%5))IR" ;
485
+
486
+ auto g = std::make_shared<torch::jit::Graph>();
487
+
488
+ torch::jit::parseIR (graph, g.get ());
489
+
490
+ auto in = at::randint (1 , 10 , {1 , 3 }, {at::kCUDA });
491
+
492
+ auto jit_in = at::clone (in);
493
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
494
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
495
+
496
+ auto trt_in = at::clone (in);
497
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
498
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {trt_in});
499
+
500
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
501
+
502
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
503
+ }
504
+
505
+ TEST (Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectly) {
506
+ const auto graph = R"IR(
507
+ graph(%x.1 : Tensor):
508
+ %2 : int = prim::Constant[value=3]()
509
+ %3 : None = prim::Constant()
510
+ %4 : None = prim::Constant()
511
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
512
+ return (%5))IR" ;
513
+
514
+ auto g = std::make_shared<torch::jit::Graph>();
515
+
516
+ torch::jit::parseIR (graph, g.get ());
517
+
518
+ auto in = at::randint (1 , 10 , {1 , 3 }, {at::kCUDA });
519
+
520
+ auto jit_in = at::clone (in);
521
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
522
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
523
+
524
+ auto trt_in = at::clone (in);
525
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
526
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
527
+
528
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
529
+
530
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
531
+ }
532
+
533
+ TEST (Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectlyWithDynamicInput) {
534
+ const auto graph = R"IR(
535
+ graph(%x.1 : Tensor):
536
+ %2 : int = prim::Constant[value=3]()
537
+ %3 : None = prim::Constant()
538
+ %4 : None = prim::Constant()
539
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
540
+ return (%5))IR" ;
541
+
542
+ auto g = std::make_shared<torch::jit::Graph>();
543
+
544
+ torch::jit::parseIR (graph, g.get ());
545
+
546
+ auto in = at::randint (1 , 10 , {1 , 3 }, {at::kCUDA });
547
+
548
+ auto jit_in = at::clone (in);
549
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
550
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
551
+
552
+ auto trt_in = at::clone (in);
553
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
554
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {trt_in});
555
+
556
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
557
+
558
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
559
+ }
560
+
561
+ TEST (Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectly) {
562
+ const auto graph = R"IR(
563
+ graph(%x.1 : Tensor):
564
+ %2 : int = prim::Constant[value=3]()
565
+ %3 : int = prim::Constant[value=1]()
566
+ %4 : None = prim::Constant()
567
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
568
+ return (%5))IR" ;
569
+
570
+ auto g = std::make_shared<torch::jit::Graph>();
571
+
572
+ torch::jit::parseIR (graph, g.get ());
573
+
574
+ auto in = at::randint (1 , 10 , {2 , 3 , 2 }, {at::kCUDA });
575
+
576
+ auto jit_in = at::clone (in);
577
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
578
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
579
+
580
+ auto trt_in = at::clone (in);
581
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
582
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
583
+
584
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
585
+
586
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
587
+ }
588
+
589
+ TEST (Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectlyWithDynamicInput) {
590
+ const auto graph = R"IR(
591
+ graph(%x.1 : Tensor):
592
+ %2 : int = prim::Constant[value=3]()
593
+ %3 : int = prim::Constant[value=1]()
594
+ %4 : None = prim::Constant()
595
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
596
+ return (%5))IR" ;
597
+
598
+ auto g = std::make_shared<torch::jit::Graph>();
599
+
600
+ torch::jit::parseIR (graph, g.get ());
601
+
602
+ auto in = at::randint (1 , 10 , {2 , 3 , 2 }, {at::kCUDA });
603
+
604
+ auto jit_in = at::clone (in);
605
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
606
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
607
+
608
+ auto trt_in = at::clone (in);
609
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
610
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {trt_in});
611
+
612
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
613
+
614
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
615
+ }
616
+
617
+ TEST (Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectly) {
618
+ const auto graph = R"IR(
619
+ graph(%x.1 : Tensor):
620
+ %2 : int = prim::Constant[value=3]()
621
+ %3 : None = prim::Constant()
622
+ %4 : None = prim::Constant()
623
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
624
+ return (%5))IR" ;
625
+
626
+ auto g = std::make_shared<torch::jit::Graph>();
627
+
628
+ torch::jit::parseIR (graph, g.get ());
629
+
630
+ auto in = at::randint (1 , 10 , {2 , 3 , 2 }, {at::kCUDA });
631
+
632
+ auto jit_in = at::clone (in);
633
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
634
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
635
+
636
+ auto trt_in = at::clone (in);
637
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
638
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
639
+
640
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
641
+
642
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
643
+ }
644
+
645
+ TEST (Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicInput) {
646
+ const auto graph = R"IR(
647
+ graph(%x.1 : Tensor):
648
+ %2 : int = prim::Constant[value=3]()
649
+ %3 : None = prim::Constant()
650
+ %4 : None = prim::Constant()
651
+ %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
652
+ return (%5))IR" ;
653
+
654
+ auto g = std::make_shared<torch::jit::Graph>();
655
+
656
+ torch::jit::parseIR (graph, g.get ());
657
+
658
+ auto in = at::randint (1 , 10 , {2 , 3 , 2 }, {at::kCUDA });
659
+
660
+ auto jit_in = at::clone (in);
661
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
662
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
663
+
664
+ auto trt_in = at::clone (in);
665
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
666
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {trt_in});
667
+
668
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
669
+
670
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
671
+ }
0 commit comments