@@ -579,4 +579,38 @@ TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
579
579
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {});
580
580
581
581
ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
582
+ }
583
+
584
+ TEST (Evaluators, AtenFormatEvaluatesCorrectly) {
585
+ const auto graph = R"IR(
586
+ graph(%x_1 : Tensor, %x_2 : Tensor):
587
+ %0 : int = prim::Constant[value=1]()
588
+ %1 : str = prim::Constant[value="res{}_{}_"]()
589
+ %2 : int = prim::Constant[value=5]()
590
+ %2.1 : int = prim::Constant[value=2]()
591
+ %3 : str = prim::Constant[value="res5_2_"]()
592
+ %4 : str = aten::format(%1, %2, %2.1)
593
+ %5 : bool = aten::eq(%3, %4)
594
+ %y : Tensor = prim::If(%5)
595
+ block0():
596
+ %194 : Tensor = aten::add(%x_1, %x_2, %0)
597
+ -> (%194)
598
+ block1():
599
+ %195 : Tensor = aten::sub(%x_1, %x_2, %0)
600
+ -> (%195)
601
+ return (%y))IR" ;
602
+ auto g = std::make_shared<torch::jit::Graph>();
603
+ torch::jit::parseIR (graph, &*g);
604
+
605
+ auto in0 = at::randint (1 , 10 , {3 , 4 }, {at::kCUDA });
606
+ auto in1 = in0.clone ();
607
+
608
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
609
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in0, in1});
610
+
611
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
612
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in0, in1});
613
+
614
+ ASSERT_TRUE (
615
+ torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
582
616
}
0 commit comments