
Commit b82f743

angelayi authored and guilhermeleobas committed

[export] Support torch.sym* ops (pytorch#115854)

Fixes pytorch#108830 and pytorch/executorch#1379 (comment)

Pull Request resolved: pytorch#115854
Approved by: https://github.com/zhxchen17

1 parent f28a6fd commit b82f743

5 files changed: +62 -11 lines
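
For context, the torch.sym* functions this commit teaches export to handle are thin wrappers that behave like their math/builtin counterparts on plain Python numbers, but return symbolic values (SymInt/SymFloat/SymBool) when given symbolic inputs. A minimal illustration of their eager behavior (my sketch, not part of the commit):

import torch

print(torch.sym_sqrt(16))          # 4.0, same as math.sqrt on a plain int
print(torch.sym_max(3, 7))         # 7, like built-in max
print(torch.sym_not(True))         # False, like `not`
print(torch.sym_ite(False, 1, 2))  # 2, a traceable `t if b else f`

Under export with dynamic shapes, x.shape[0] is a SymInt, so these calls stay symbolic and appear in the graph as call_function nodes whose targets are the torch.sym_* callables themselves, which is exactly what the pass base, serializer, and verifier changes below account for.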

test/export/test_export.py

Lines changed: 13 additions & 0 deletions
@@ -1934,6 +1934,19 @@ def forward(self, x, y):
         self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0])
         self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0])
 
+    @testing.expectedFailureNonStrict
+    def test_sym_sqrt(self):
+        import math
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x / torch.sym_sqrt(x.shape[0])
+
+        ep = export(M(), (torch.ones(16, 4),), dynamic_shapes={'x': {0: Dim("dim")}})
+        _ExportPassBase()(ep.graph_module)
+        FileCheck().check_count(
+            "torch.sym_sqrt", 1, exactly=True
+        ).run(ep.graph_module.code)
+
     def test_check_specialized_int(self):
         class SingleOp(torch.nn.Module):
             def __init__(self):
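
To see what the new test buys end to end, here is a hedged sketch (it assumes the public torch.export.Dim and ExportedProgram.module() APIs of this era) of using the exported program at a size other than the traced one:

import math
import torch
from torch.export import export, Dim

class M(torch.nn.Module):
    def forward(self, x):
        # Under Dim("dim"), x.shape[0] is a SymInt, so sym_sqrt stays symbolic.
        return x / torch.sym_sqrt(x.shape[0])

ep = export(M(), (torch.ones(16, 4),), dynamic_shapes={"x": {0: Dim("dim")}})
# One exported graph serves other batch sizes.
out = ep.module()(torch.ones(9, 4))
print(torch.allclose(out, torch.full((9, 4), 1 / math.sqrt(9))))  # True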

test/export/test_serialize.py

Lines changed: 10 additions & 1 deletion
@@ -11,7 +11,7 @@
 
 import torch
 import torch._dynamo as torchdynamo
-from torch.export import export, save, load
+from torch.export import export, save, load, Dim
 from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
 from torch._export.db.examples import all_examples
 from torch._export.serde.serialize import (
@@ -471,6 +471,15 @@ def forward(self, x, y, z):
         inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4))
         self.check_graph(MyModule(), inputs)
 
+    def test_sym_ite(self):
+        def f(x):
+            b = x.shape[0] == 5
+            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
+            return ret
+
+        dynamic_shapes = {'x': {0: Dim("dim0"), 1: Dim("dim1")}}
+        self.check_graph(f, (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes)
+
     @parametrize(
         "name,case",
         get_filtered_export_db_tests(),
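
check_graph round-trips the traced function through serialization and compares the graphs; roughly the following, sketched here with the public save/load API instead of the test helper (hedged; the wrapper Module and exact calls are my own):

import io
import torch
from torch.export import export, save, load, Dim

class M(torch.nn.Module):
    def forward(self, x):
        # b is a SymBool under dynamic shapes; sym_ite keeps the select symbolic.
        return x * torch.sym_ite(x.shape[0] == 5, x.shape[0], x.shape[1])

ep = export(M(), (torch.ones(4, 5),),
            dynamic_shapes={"x": {0: Dim("dim0"), 1: Dim("dim1")}})
buf = io.BytesIO()
save(ep, buf)   # sym_ite is serialized by name, as a "torch.sym_ite" target
buf.seek(0)
restored = load(buf)
print(torch.equal(restored.module()(torch.ones(4, 5)),
                  ep.module()(torch.ones(4, 5))))  # True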

torch/_export/pass_base.py

Lines changed: 14 additions & 1 deletion
@@ -2,7 +2,7 @@
 import traceback
 import typing
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 from functorch.experimental.control_flow import _unstack_pytree
@@ -29,6 +29,16 @@
 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
 
 
+_TORCH_SYM_OPS: Set[Callable] = {
+    torch.sym_int,
+    torch.sym_ite,
+    torch.sym_max,
+    torch.sym_min,
+    torch.sym_not,
+    torch.sym_sqrt,
+}
+
+
 class ExportPassBaseError(RuntimeError):
     pass
 
@@ -182,6 +192,9 @@ def call_function(
         elif getattr(target, "__module__", None) in {"_operator", "math"}:
             assert callable(target)
             return self.callback.call_sym(target, args, meta)
+        elif target in _TORCH_SYM_OPS:
+            assert callable(target)
+            return self.callback.call_sym(target, args, meta)
         elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
             return self.callback.call_operator(
                 target,
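
The new branch mirrors the existing `_operator`/`math` handling: torch.sym* targets are plain Python callables rather than OpOverloads, so they must go to call_sym, not call_operator. A self-contained sketch of the dispatch order (the classify helper is hypothetical, for illustration only):

import torch

_TORCH_SYM_OPS = {
    torch.sym_int, torch.sym_ite, torch.sym_max,
    torch.sym_min, torch.sym_not, torch.sym_sqrt,
}

def classify(target):
    # Mirrors call_function's elif chain, in order.
    if getattr(target, "__module__", None) in {"_operator", "math"}:
        return "call_sym"       # builtin symbolic arithmetic
    if target in _TORCH_SYM_OPS:
        return "call_sym"       # torch.sym_* callables, the case added here
    if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
        return "call_operator"  # real ATen ops
    return "other"

print(classify(torch.sym_sqrt))             # call_sym
print(classify(torch.ops.aten.add.Tensor))  # call_operator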

torch/_export/serde/serialize.py

Lines changed: 16 additions & 8 deletions
@@ -2,6 +2,7 @@
 import copy
 import dataclasses
 import heapq
+import inspect
 import io
 import json
 import logging
@@ -152,6 +153,12 @@ def _reverse_map(d: Dict[Any, Enum]):
     operator.sub,
     operator.floordiv,
     operator.mod,
+    torch.sym_sqrt,
+    torch.sym_int,
+    torch.sym_ite,
+    torch.sym_max,
+    torch.sym_min,
+    torch.sym_sqrt,
 }
 
 
@@ -162,6 +169,7 @@ def _reverse_map(d: Dict[Any, Enum]):
     operator.ge,
     operator.lt,
     operator.gt,
+    torch.sym_not,
 }
 
 
@@ -369,7 +377,7 @@ def handle_call_function(self, node: torch.fx.Node):
             meta_val = node.meta["val"]
             ex_node = Node(
                 target=self.serialize_operator(node.target),
-                inputs=self.serialize_sym_op_inputs(node.args),
+                inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))],
                 metadata=self.serialize_metadata(node),
             )
@@ -378,7 +386,7 @@ def handle_call_function(self, node: torch.fx.Node):
             meta_val = node.meta["val"]
             ex_node = Node(
                 target=self.serialize_operator(node.target),
-                inputs=self.serialize_sym_op_inputs(node.args),
+                inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))],
                 metadata=self.serialize_metadata(node),
             )
@@ -453,9 +461,9 @@ def export_nn_module_stack(val):
 
         return ret
 
-    def serialize_sym_op_inputs(self, args) -> List[NamedArgument]:
+    def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
         serialized_args = []
-        args_names = ["a", "b"]
+        args_names = inspect.signature(op).parameters.keys()
         for args_name, arg in zip(args_names, args):
             serialized_args.append(
                 NamedArgument(name=args_name, arg=self.serialize_input(arg))
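
The point of threading the op through serialize_sym_op_inputs: the hardcoded ["a", "b"] only fit binary operators, while torch.sym_ite takes three arguments. inspect.signature recovers each op's real parameter names; for instance (output reflects the operator/torch sources of this era, so treat the exact names as an assumption):

import inspect
import operator
import torch

print(list(inspect.signature(operator.add).parameters))   # ['a', 'b']
print(list(inspect.signature(torch.sym_ite).parameters))  # ['b', 't', 'f']
print(list(inspect.signature(torch.sym_sqrt).parameters)) # ['a']
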
@@ -1015,9 +1023,9 @@ def deserialize_operator(self, serialized_target: str):
         if serialized_target.startswith("_operator"):  # TODO(zhxchen17) Follow up on this.
             module = operator
             serialized_target_names = serialized_target.split(".")[1:]
-        elif serialized_target.startswith("torch.ops"):
-            module = torch.ops
-            serialized_target_names = serialized_target.split(".")[2:]
+        elif serialized_target.startswith("torch"):
+            module = torch  # type: ignore[misc]
+            serialized_target_names = serialized_target.split(".")[1:]
         else:  # TODO(zhxchen17) Don't catch all here.
             return serialized_target
 
@@ -1150,7 +1158,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
         return self.graph
 
     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
-        if target.__module__ == "_operator":  # TODO(zhxchen17) Follow up on this.
+        if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
             name = serialized_node.outputs[0].value.as_name
             args = self.deserialize_sym_op_inputs(serialized_node.inputs)
 
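
Broadening the deserialize_operator prefix from "torch.ops" to "torch" lets targets serialized as, say, "torch.sym_sqrt" resolve, while still covering "torch.ops.aten..." targets, since torch.ops is itself an attribute of torch. A small sketch of the getattr walk (the resolve helper is hypothetical):

import torch

def resolve(serialized_target: str):
    if not serialized_target.startswith("torch"):
        return serialized_target
    target = torch
    for name in serialized_target.split(".")[1:]:
        # e.g. torch -> sym_sqrt, or torch -> ops -> aten -> add -> Tensor
        target = getattr(target, name)
    return target

print(resolve("torch.sym_sqrt") is torch.sym_sqrt)                        # True
print(resolve("torch.ops.aten.add.Tensor") is torch.ops.aten.add.Tensor)  # True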

torch/_export/verifier.py

Lines changed: 9 additions & 1 deletion
@@ -158,7 +158,15 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]:
             return ret
 
         # TODO Remove this allowlist.
-        _allowed_torch_functions = (torch.autograd.grad_mode.set_grad_enabled,)
+        _allowed_torch_functions = (
+            torch.autograd.grad_mode.set_grad_enabled,
+            torch.sym_int,
+            torch.sym_ite,
+            torch.sym_max,
+            torch.sym_min,
+            torch.sym_not,
+            torch.sym_sqrt,
+        )
 
         if not isinstance(op, _allowed_op_types()):
             if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
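
With the expanded allowlist, exported graphs containing torch.sym_* call_function nodes now pass verification instead of being rejected as disallowed torch functions. Roughly (a simplified sketch of just the membership check, not the verifier's full logic):

import torch

_allowed_torch_functions = (
    torch.autograd.grad_mode.set_grad_enabled,
    torch.sym_int, torch.sym_ite, torch.sym_max,
    torch.sym_min, torch.sym_not, torch.sym_sqrt,
)

def check_torch_function(op):
    # Simplified: the real verifier also accepts allowed op types
    # (OpOverloads etc.) and builtin operators before consulting this allowlist.
    if op not in _allowed_torch_functions:
        raise RuntimeError(f"Operator {op} is not allowed")

check_torch_function(torch.sym_not)  # passes after this commit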
