@@ -16,6 +16,15 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
1616 graph(%x, %device, %dtype, %nb, %copy, %format):
1717 %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
1818 return (%out))IR" ;
19+ std::string to_dtype_layout_pattern = R"IR(
20+ graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
21+ %out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
22+ return (%out))IR" ;
23+
24+ std::string to_dtype_multi_input_pattern = R"IR(
25+ graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
26+ %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
27+ return (%out))IR" ;
1928
2029 std::string to_type_as_pattern = R"IR(
2130 graph(%input, %other):
@@ -34,6 +43,11 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
3443 map_aten_device_to_dtype.RegisterRewritePattern (to_device_pattern, to_dtype_pattern);
3544 map_aten_device_to_dtype.runOnGraph (graph);
3645
46+ // replace aten::to.dtype_layout with aten::to.dtype
47+ torch::jit::SubgraphRewriter map_aten_dtype_layout;
48+ map_aten_dtype_layout.RegisterRewritePattern (to_dtype_layout_pattern, to_dtype_multi_input_pattern);
49+ map_aten_dtype_layout.runOnGraph (graph);
50+
3751 // replace aten::type_as with aten::to.other
3852 torch::jit::SubgraphRewriter map_aten_type_as_to_other;
3953 map_aten_type_as_to_other.RegisterRewritePattern (to_type_as_pattern, to_other_pattern);
0 commit comments