@@ -8,22 +8,14 @@ namespace lowering {
8
8
namespace passes {
9
9
10
10
void ReduceToOperation (std::shared_ptr<torch::jit::Graph>& graph) {
11
- std::string to_device_pattern = R"IR(
12
- graph(%x, %device, %dtype, %nb, %copy, %format):
13
- %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
14
- return (%out))IR" ;
15
- std::string to_dtype_pattern = R"IR(
16
- graph(%x, %device, %dtype, %nb, %copy, %format):
17
- %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
18
- return (%out))IR" ;
19
11
std::string to_dtype_layout_pattern = R"IR(
20
- graph(%x, %device , %dtype , %layout , %pm, %nb, %copy, %format):
21
- %out : Tensor = aten::to(%x, %device , %dtype , %layout , %pm, %nb, %copy, %format)
12
+ graph(%x, %dtype , %layout , %device , %pm, %nb, %copy, %format):
13
+ %out : Tensor = aten::to(%x, %dtype , %layout , %device , %pm, %nb, %copy, %format)
22
14
return (%out))IR" ;
23
15
24
16
std::string to_dtype_multi_input_pattern = R"IR(
25
- graph(%x, %device , %dtype , %layout , %pm, %nb, %copy, %format):
26
- %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
17
+ graph(%x, %dtype , %layout , %device , %pm, %nb, %copy, %format):
18
+ %out : Tensor = aten::to(%x, %device, % dtype, %nb, %copy, %format)
27
19
return (%out))IR" ;
28
20
29
21
std::string to_type_as_pattern = R"IR(
@@ -38,11 +30,6 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
38
30
%out : Tensor = aten::to(%input, %other, %5, %5, %6)
39
31
return (%out))IR" ;
40
32
41
- // replace aten::to.device with aten::to.dtype
42
- torch::jit::SubgraphRewriter map_aten_device_to_dtype;
43
- map_aten_device_to_dtype.RegisterRewritePattern (to_device_pattern, to_dtype_pattern);
44
- map_aten_device_to_dtype.runOnGraph (graph);
45
-
46
33
// replace aten::to.dtype_layout with aten::to.dtype
47
34
torch::jit::SubgraphRewriter map_aten_dtype_layout;
48
35
map_aten_dtype_layout.RegisterRewritePattern (to_dtype_layout_pattern, to_dtype_multi_input_pattern);
0 commit comments