Skip to content

Commit 4b3ae3a

Browse files
committed
feat: Implement lowering for aten::to.dtype schema
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent abf9b50 commit 4b3ae3a

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

core/lowering/passes/reduce_to.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
1616
graph(%x, %device, %dtype, %nb, %copy, %format):
1717
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
1818
return (%out))IR";
19+
std::string to_dtype_layout_pattern = R"IR(
20+
graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
21+
%out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
22+
return (%out))IR";
23+
24+
std::string to_dtype_multi_input_pattern = R"IR(
25+
graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
26+
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
27+
return (%out))IR";
1928

2029
std::string to_type_as_pattern = R"IR(
2130
graph(%input, %other):
@@ -34,6 +43,11 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
3443
map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_dtype_pattern);
3544
map_aten_device_to_dtype.runOnGraph(graph);
3645

46+
// replace aten::to.dtype_layout with aten::to.dtype
47+
torch::jit::SubgraphRewriter map_aten_dtype_layout;
48+
map_aten_dtype_layout.RegisterRewritePattern(to_dtype_layout_pattern, to_dtype_multi_input_pattern);
49+
map_aten_dtype_layout.runOnGraph(graph);
50+
3751
// replace aten::type_as with aten::to.other
3852
torch::jit::SubgraphRewriter map_aten_type_as_to_other;
3953
map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern);

0 commit comments

Comments
 (0)