@@ -16,6 +16,15 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
     graph(%x, %device, %dtype, %nb, %copy, %format):
       %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
       return (%out))IR";
+  std::string to_dtype_layout_pattern = R"IR(
+    graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
+      %out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
+      return (%out))IR";
+
+  std::string to_dtype_multi_input_pattern = R"IR(
+    graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
+      %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
+      return (%out))IR";

   std::string to_type_as_pattern = R"IR(
     graph(%input, %other):
@@ -34,6 +43,11 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
   map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_dtype_pattern);
   map_aten_device_to_dtype.runOnGraph(graph);

+  // replace aten::to.dtype_layout with aten::to.dtype
+  torch::jit::SubgraphRewriter map_aten_dtype_layout;
+  map_aten_dtype_layout.RegisterRewritePattern(to_dtype_layout_pattern, to_dtype_multi_input_pattern);
+  map_aten_dtype_layout.runOnGraph(graph);
+
   // replace aten::type_as with aten::to.other
   torch::jit::SubgraphRewriter map_aten_type_as_to_other;
   map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern);
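For context, the new pass relies on torch::jit::SubgraphRewriter's string-based pattern matching: the first IR string is matched against the graph and every match is replaced by the body of the second. The standalone sketch below is not part of this commit; the headers, the toy graph, and main() are assumptions added for illustration. It exercises the same pattern/replacement pair on a hand-written graph so the rewrite can be inspected.

// Illustrative sketch only: runs the aten::to.dtype_layout -> aten::to rewrite
// from this patch on a small, hand-written graph. Assumes PyTorch's C++ JIT
// headers are available; the graph text and main() are not from the commit.
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include <iostream>
#include <memory>
#include <string>

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();

  // A toy graph that calls the dtype_layout overload of aten::to.
  torch::jit::parseIR(R"IR(
    graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
      %out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
      return (%out))IR",
      graph.get());

  // Same pattern/replacement pair the patch registers in ReduceToOperation.
  std::string to_dtype_layout_pattern = R"IR(
    graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
      %out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
      return (%out))IR";
  std::string to_dtype_multi_input_pattern = R"IR(
    graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
      %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
      return (%out))IR";

  torch::jit::SubgraphRewriter map_aten_dtype_layout;
  map_aten_dtype_layout.RegisterRewritePattern(
      to_dtype_layout_pattern, to_dtype_multi_input_pattern);
  map_aten_dtype_layout.runOnGraph(graph);

  // After the pass, the node should read
  // aten::to(%x, %device, %dtype, %nb, %copy, %format).
  std::cout << *graph;
  return 0;
}

Note the design of the replacement pattern: it keeps the full eight-value input list so it lines up with the matched subgraph, but leaves %layout and %other unused, which effectively drops those arguments while %device, %dtype, %nb, %copy, and %format still flow into the plain aten::to call.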