Skip to content

Commit a55d6b6

Browse files
angelayi and facebook-github-bot
authored and committed
Serialize memory.alloc
Summary: `memory.alloc` takes a tuple of `(size, dtype)`, or a singleton list containing such a tuple, so when we serialize it we can encode the spec as a single string of the form `"size;dtype"`. Reviewed By: zhxchen17. Differential Revision: D47311041. fbshipit-source-id: 59621f20831620e3291e75f0c1fb720f556dffd0
1 parent e379915 commit a55d6b6

File tree

4 files changed

+213
-16
lines changed

4 files changed

+213
-16
lines changed

exir/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ def graph_module(self) -> torch.fx.GraphModule:
342342
def dump_graph_module(self) -> torch.fx.GraphModule:
343343
return self.graph_module
344344

345+
def dump_exported_program(self) -> ExirExportedProgram:
    """Return the exported program stored on this object.

    Returns:
        ``self.exported_program``, unchanged.
    """
    program = self.exported_program
    return program
347+
345348

346349
# TODO(ycao): Move Executorch dialect to its own file
347350
@compatibility(is_backward_compatible=False)

exir/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.utils import _pytree as pytree
1010
from typing_extensions import TypeAlias
1111

12-
TensorAllocSpec: TypeAlias = Tuple[List[int], torch.dtype]
12+
TensorAllocSpec: TypeAlias = Tuple[Tuple[int], torch.dtype]
1313
AllocSpec: TypeAlias = Union[
1414
TensorAllocSpec,
1515
List[TensorAllocSpec],

exir/serde/serialize.py

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,119 @@
33
import copy
44
import dataclasses
55
import json
6-
from typing import Any, Dict, Optional, Tuple
6+
import logging
7+
import operator
8+
from typing import Any, Callable, Dict, List, Optional, Tuple
79

810
import executorch.exir as exir
11+
import executorch.exir.memory as memory
912
import torch
1013
import torch._export.exported_program as ep
1114
import torch._export.serde.schema as schema
1215
import torch._export.serde.serialize as export_serialize
1316
from torch.fx.experimental import symbolic_shapes
1417

1518

19+
log: logging.Logger = logging.getLogger(__name__)
20+
21+
1622
class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
1723
def __init__(
1824
self, graph_signature: ep.ExportGraphSignature, call_spec: ep.CallSpec
1925
) -> None:
2026
super().__init__(graph_signature, call_spec)
2127
self.state_dict: Dict[str, torch.Tensor] = {} # TODO(T157676982)
2228

29+
def handle_call_function(self, node: torch.fx.Node) -> None:
    """Serialize a ``call_function`` node, special-casing ``memory.alloc``.

    A node whose target is ``memory.alloc`` is written directly as a
    ``schema.Node`` with the string target ``"memory.alloc"`` and appended
    to ``self.graph_state.nodes``. Any other target is handed to the
    parent class implementation.
    """
    assert node.op == "call_function"

    # Everything except memory.alloc is handled by the base serializer.
    if node.target is not memory.alloc:
        super().handle_call_function(node)
        return

    alloc_inputs = self.serialize_alloc_inputs(node.args)
    alloc_outputs = self.serialize_arbitrary_outputs(node)
    alloc_metadata = self.serialize_metadata(node)
    self.graph_state.nodes.append(
        schema.Node(
            target="memory.alloc",
            inputs=alloc_inputs,
            outputs=alloc_outputs,
            metadata=alloc_metadata,
        )
    )
43+
44+
def serialize_alloc_inputs(
    self, inputs  # pyre-ignore
) -> List[schema.NamedArgument]:
    """
    Serialize the inputs to the memory.alloc function.

    ``inputs`` must hold exactly one positional argument: either a single
    ``(size, dtype)`` alloc spec or a singleton list wrapping one such
    spec. Since there's no specific spec to take argument names from, we
    just serialize the input with a dummy name that records which form was
    used: "alloc_arg" for the bare spec and "alloc_list" for the singleton
    list. The spec itself is serialized into a string "size;dtype".

    Returns:
        A one-element list holding the named, string-encoded argument.
    """
    assert len(inputs) == 1

    def serialize_alloc_spec(alloc_spec: memory.AllocSpec) -> schema.Argument:
        size = alloc_spec[0]
        # Always emit the size in tuple form "(d0, d1, ...)". A list-typed
        # size would otherwise be written as "[d0, d1, ...]", which a reader
        # that only strips parentheses cannot parse back.
        if isinstance(size, (list, tuple)):
            size = tuple(size)
        dtype = export_serialize._TORCH_TO_SERIALIZE_DTYPE[alloc_spec[1]].value
        return schema.Argument.create(as_string=f"{size};{dtype}")

    alloc_input = inputs[0]
    if isinstance(alloc_input, list):
        # Singleton list
        assert len(alloc_input) == 1
        name, alloc_spec = "alloc_list", alloc_input[0]
    else:
        # Single value
        name, alloc_spec = "alloc_arg", alloc_input

    return [schema.NamedArgument(name=name, arg=serialize_alloc_spec(alloc_spec))]
74+
75+
def serialize_arbitrary_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
    """Serialize the outputs of ``node`` from its ``meta["val"]`` entry.

    A single ``torch.Tensor`` value yields one ``as_tensor`` argument named
    after the node. Otherwise the value is iterated: a one-element value is
    emitted as a single ``as_tensors`` argument (list of tensors), and any
    other length is emitted as one ``as_tensor`` argument per element
    (multiple tensor returns).
    """
    val = node.meta["val"]

    # Single value return.
    if isinstance(val, torch.Tensor):
        single = self.serialize_tensor_output(node.name, val)
        return [schema.Argument.create(as_tensor=single)]

    # There are two possibilities at this point:
    # - This operator returns a list of Tensors.
    # - This operator returns multiple Tensors.
    #
    # Either way, gather the tensor outputs under the correct names. For
    # consistent naming with FX, reuse the name of the downstream `getitem`
    # node that reads each index.
    getitem_names = {
        user.args[1]: user.name
        for user in node.users
        if user.target is operator.getitem
    }

    # FX does not emit a getitem node for outputs that are unused, but the
    # number of outputs still has to match the schema, so those positions
    # get a dummy name.
    tensor_args = [
        self.serialize_tensor_output(
            getitem_names.get(position, f"{node.name}_unused_{position}"),
            element,
        )
        for position, element in enumerate(val)
    ]

    if len(val) != 1:
        # The operator returns multiple tensors.
        return [schema.Argument.create(as_tensor=arg) for arg in tensor_args]
    # The operator returns a list of tensors.
    return [schema.Argument.create(as_tensors=tensor_args)]
118+
23119
# pyre-ignore
24120
def serialize_input(self, arg) -> schema.Argument:
25121
if isinstance(arg, torch.fx.Node):
@@ -71,7 +167,75 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
71167
super().__init__()
72168
self.state_dict: Dict[str, Any] = state_dict # TODO(T157676982)
73169

74-
# TODO(angelayi): implement for delegation
170+
# pyre-ignore
171+
def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
    """Rebuild one FX node from ``serialized_node``.

    Three cases are handled:

    * ``target == "memory.alloc"``: a ``call_function`` node targeting
      ``memory.alloc`` is created from the string-encoded alloc spec.
    * any other string ``target`` (the operator could not be resolved): a
      placeholder callable carrying the operator's name is used as the
      target so that the node still exists in the graph. The placeholder
      raises ``NotImplementedError`` if it is ever called.
    * everything else is delegated to the parent class implementation.

    Previously the unresolved-operator branch built the placeholder and
    then returned without creating any node, silently dropping the
    serialized node from the graph.
    """
    if target == "memory.alloc":
        args = self.deserialize_alloc_inputs(serialized_node.inputs)
        fx_node = self.graph.create_node(
            "call_function", memory.alloc, args, {}, "alloc"
        )

        self.deserialize_arbitrary_outputs(serialized_node, fx_node)

        fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
        return

    if isinstance(target, str):
        # Create a dummy fake op if the target does not exist
        # because we cannot create a call_function node w/o a
        # callable target
        log.warning(
            f"Could not find operator {target}. Returning fake operator."
        )  # noqa: G004

        # pyre-ignore
        def fake_op(x):
            raise NotImplementedError("Fake op is not meant to be run.")

        fake_op.__name__ = target

        # No operator schema is available for an unknown target, so the
        # serialized inputs are restored positionally, in order.
        # NOTE(review): inputs that were originally keyword arguments also
        # end up positional here — confirm this is acceptable.
        args = tuple(
            self.deserialize_input(named_arg.arg)
            for named_arg in serialized_node.inputs
        )
        fx_node = self.graph.create_node("call_function", fake_op, args, {}, None)

        self.deserialize_arbitrary_outputs(serialized_node, fx_node)

        fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
        return

    super().deserialize_node(serialized_node, target)
200+
201+
# pyre-ignore
202+
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
    """Decode the single string-encoded argument of a ``memory.alloc`` node.

    ``serialized_inputs`` must contain exactly one ``as_string`` argument
    of the form "size;dtype". The argument's name selects the shape of the
    result: "alloc_arg" yields ``((size, dtype),)`` (bare spec), any other
    name yields ``([(size, dtype)],)`` (singleton list).

    Returns:
        A one-element tuple usable as the positional args of the node.
    """

    def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
        serialized_alloc_spec_elems = serialized_alloc_spec.split(";")
        assert len(serialized_alloc_spec_elems) == 2
        size_text, dtype_text = serialized_alloc_spec_elems

        # Accept both "(d0, d1)" and "[d0, d1]": the size may have been
        # written from a tuple or from a list. Empty pieces (from "()" or
        # the trailing comma of "(d0,)") are skipped.
        size = tuple(
            int(dim) for dim in size_text.strip("()[] ").split(",") if dim.strip()
        )
        dtype = export_serialize._SERIALIZE_TO_TORCH_DTYPE[int(dtype_text)]
        return (size, dtype)

    # Check the length before indexing so that an empty input list fails
    # this assertion instead of raising IndexError.
    assert len(serialized_inputs) == 1
    alloc_input = serialized_inputs[0]
    assert alloc_input.arg.type == "as_string"

    alloc_spec = deserialize_alloc_spec(alloc_input.arg.value)

    # Single value
    if alloc_input.name == "alloc_arg":
        return (alloc_spec,)

    # Singleton list value
    return ([alloc_spec],)
227+
228+
def deserialize_arbitrary_outputs(
    self, serialized_node: schema.Node, fx_node: torch.fx.Node
) -> None:
    """Associate the serialized outputs of a node with ``fx_node``.

    When the node has exactly one output of type ``as_tensor``, that
    output's name is synced to ``fx_node`` via ``sync_fx_node``. Every
    other output layout is forwarded to ``deserialize_multiple_outputs``.
    """
    outputs = serialized_node.outputs
    has_single_tensor = len(outputs) == 1 and outputs[0].type == "as_tensor"

    if not has_single_tensor:
        self.deserialize_multiple_outputs(serialized_node, fx_node)
        return None

    # Single tensor return
    return self.sync_fx_node(outputs[0].as_tensor.name, fx_node)
75239

76240
# pyre-ignore
77241
def deserialize_input(self, inp: schema.Argument) -> Any:

exir/tests/test_serde.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77

88
import torch
99
from executorch.exir.serde.serialize import deserialize, serialize
10+
from torch._export.exported_program import ExportedProgram as TorchExportedProgram
1011
from torch.utils import _pytree as pytree
1112

1213

1314
# Tests for serializing to json and back
1415
class TestSerde(unittest.TestCase):
1516
def check_ep(
1617
self,
17-
ep1: exir.ExportedProgram,
18-
ep2: exir.ExportedProgram,
18+
ep1: TorchExportedProgram,
19+
ep2: TorchExportedProgram,
1920
inputs: Tuple[exir.Value, ...],
2021
) -> None:
2122
"""
@@ -30,6 +31,20 @@ def check_ep(
3031
for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
3132
self.assertTrue(torch.allclose(orig, loaded))
3233

34+
# pyre-ignore
35+
def check_serde(self, m, inputs) -> None:
    """Round-trip ``m`` through serialize/deserialize at three stages.

    The module is captured with ``exir.capture`` (``pt2_mode=True``), then
    lowered with ``to_edge()`` and ``to_executorch()``. At each stage the
    program is serialized, deserialized, and compared against the original
    with ``check_ep`` on ``inputs``.
    """

    def assert_round_trip(program) -> None:
        restored = deserialize(*serialize(program))
        self.check_ep(program, restored, inputs)

    aten_program = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True))
    assert_round_trip(aten_program)

    edge_program = aten_program.to_edge()
    assert_round_trip(edge_program)

    executorch_program = edge_program.to_executorch().dump_exported_program()
    assert_round_trip(executorch_program)
47+
3348
def test_basic(self) -> None:
3449
class MyModule(torch.nn.Module):
3550
def __init__(self):
@@ -42,20 +57,35 @@ def forward(self, x):
4257
return x, x.clone()
4358

4459
inputs = (torch.ones([512], requires_grad=True),)
45-
aten = exir.capture(MyModule(), inputs, exir.CaptureConfig(pt2_mode=True))
46-
aten_new = deserialize(*serialize(aten))
47-
self.check_ep(aten, aten_new, inputs)
60+
self.check_serde(MyModule(), inputs)
4861

49-
def test_getattr(self) -> None:
50-
class MyModule(torch.nn.Module):
62+
def test_to_out_variant_singleon_tensor_list(self) -> None:
63+
class MyModel(torch.nn.Module):
5164
def __init__(self):
5265
super().__init__()
53-
self.linear = torch.nn.Linear(512, 512)
5466

5567
def forward(self, x):
56-
return self.linear(x)
68+
return torch.split(x, 10)
5769

58-
inputs = (torch.ones(512, 512, requires_grad=True),)
59-
aten = exir.capture(MyModule(), inputs, exir.CaptureConfig(pt2_mode=True))
60-
aten_new = deserialize(*serialize(aten))
61-
self.check_ep(aten, aten_new, inputs)
70+
def get_random_inputs(self):
71+
return (torch.randn(10),)
72+
73+
model = MyModel()
74+
inputs = model.get_random_inputs()
75+
self.check_serde(model, inputs)
76+
77+
def test_to_out_variant_multiple_out(self) -> None:
    """Check serde round-tripping for a module returning both ``torch.topk`` results."""

    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            values, indices = torch.topk(x, 5)
            return (values, indices)

        def get_random_inputs(self):
            return (torch.randn(10),)

    module = MyModel()
    sample_inputs = module.get_random_inputs()
    self.check_serde(module, sample_inputs)

0 commit comments

Comments
 (0)