Skip to content

Commit bde8ee0

Browse files
committed
feat: Implement test case for aten::to.dtype lowering
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 4b3ae3a commit bde8ee0

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

core/lowering/passes/reduce_to.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
1717
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
1818
return (%out))IR";
1919
std::string to_dtype_layout_pattern = R"IR(
20-
graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
21-
%out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
20+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
21+
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
2222
return (%out))IR";
2323

2424
std::string to_dtype_multi_input_pattern = R"IR(
25-
graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
26-
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
25+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
26+
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
2727
return (%out))IR";
2828

2929
std::string to_type_as_pattern = R"IR(

tests/core/lowering/test_reduce_to_pass.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ TEST(LoweringPasses, ReduceToCorrectly) {
2828
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
2929
}
3030

31+
TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) {
32+
std::string source_graph = R"IR(
33+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
34+
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
35+
return (%out))IR";
36+
std::string target_graph = R"IR(
37+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
38+
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
39+
return (%out))IR";
40+
41+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
42+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
43+
auto sg = std::make_shared<torch::jit::Graph>();
44+
torch::jit::parseIR(source_graph, &*sg);
45+
torch_tensorrt::core::lowering::passes::ReduceToOperation(sg);
46+
47+
auto tg = std::make_shared<torch::jit::Graph>();
48+
torch::jit::parseIR(target_graph, &*tg);
49+
50+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
51+
}
52+
3153
TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) {
3254
std::string source_graph = R"IR(
3355
graph(%input, %other):

0 commit comments

Comments
 (0)