Skip to content

Commit a6ec5dc

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
Make edge dialect program a first class citizen (#508)
Summary: Pull Request resolved: #508 As building out the Inspector APIs, we realized that we want to make edge dialect program a first class citizen of ETRecord, and also require executorch program when constructing an ETRecord (used to be optional). Reviewed By: tarun292 Differential Revision: D49436040 fbshipit-source-id: 59986e75d3da141502066d152925119998426d4e
1 parent 30eb47f commit a6ec5dc

File tree

7 files changed

+86
-39
lines changed

7 files changed

+86
-39
lines changed

docs/website/docs/sdk/01_generating_etrecord.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ There are two important API's users must be aware of when dealing with ETrecord:
1313
```python
1414
generate_etrecord(
1515
etrecord_path: str,
16-
program: Optional[Union[ExecutorchProgram, MultiMethodExecutorchProgram]] = None,
16+
edge_dialect_program: ExirExportedProgram
17+
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
1718
export_modules: Optional[
1819
Dict[
1920
str, Union[MultiMethodExirExportedProgram, ExirExportedProgram]
@@ -23,13 +24,15 @@ generate_etrecord(
2324
```
2425

2526
Generates an ETRecord from the given objects and saves it to the given path.
26-
The objects that will be serialized to an ETRecord are all the graph modules present in the export_modules dict and also the graph module present in the program object, which is the closest graph module representation of what is eventually run on the device.
27+
The objects that will be serialized to an ETRecord are all the graph modules present in the export_modules dict, the graph module present in the edge dialect program object,
28+
and also the graph module present in the executorch program object, which is the closest graph module representation of what is eventually run on the device.
2729

2830
In addition to all the graph modules we also serialize the program buffer which the users can provide to the ExecuTorch runtime to run the model.
2931

3032
#### Parameters:
3133
- `etrecord_path` : Path to where the ETRecord file will be saved to.
32-
- `program`: ExecutorchProgram or MultiMethodExecutorchProgram for this model returned by the call to to_executorch()
34+
- `edge_dialect_program`: ExirExportedProgram for this model returned by the call to to_edge()
35+
- `executorch_program`: ExecutorchProgram or MultiMethodExecutorchProgram for this model returned by the call to to_executorch()
3336
- `export_modules`: Dictionary of graph modules with the key being the user provided name and the value is the corresponding exported module. The exported graph modules can be either the output of capture() or to_edge().
3437

3538
#### Returns:

sdk/etdb/_inspector_utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,24 @@
1616
from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
1717
from executorch.sdk.etrecord import ETRecord, parse_etrecord
1818

19+
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
20+
1921

2022
def gen_graphs_from_etrecord(
2123
etrecord: ETRecord,
2224
) -> Mapping[str, OperatorGraphWithStats]:
23-
if etrecord.graph_map is None:
24-
return {}
25-
return {
26-
name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module)
27-
for name, exported_program in etrecord.graph_map.items()
28-
}
25+
op_graph_map = {}
26+
if etrecord.graph_map is not None:
27+
op_graph_map = {
28+
name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module)
29+
for name, exported_program in etrecord.graph_map.items()
30+
}
31+
if etrecord.edge_dialect_program is not None:
32+
op_graph_map[EDGE_DIALECT_GRAPH_KEY] = FXOperatorGraph.gen_operator_graph(
33+
etrecord.edge_dialect_program.graph_module
34+
)
35+
36+
return op_graph_map
2937

3038

3139
# TODO: use anonymous function to avoid passing the dict around

sdk/etdb/inspector.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
2929
from executorch.sdk.etdb._inspector_utils import (
3030
create_debug_handle_to_op_node_mapping,
31+
EDGE_DIALECT_GRAPH_KEY,
3132
gen_etdump_object,
3233
gen_etrecord_object,
3334
gen_graphs_from_etrecord,
@@ -71,7 +72,7 @@ def _gen_from_event(event: ProfileEvent) -> "ProfileEventSignature":
7172
The Signature will convert these back to the intended None value
7273
"""
7374
return ProfileEventSignature(
74-
event.name,
75+
event.name or "",
7576
event.instruction_id if event.instruction_id != -1 else None,
7677
event.delegate_debug_id_int if event.delegate_debug_id_int != -1 else None,
7778
event.delegate_debug_id_str if event.delegate_debug_id_str != "" else None,
@@ -346,9 +347,6 @@ def _gen_resolve_debug_handles(
346347
)
347348

348349

349-
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
350-
351-
352350
class Inspector:
353351
"""
354352
APIs for examining model architecture and performance stats.
@@ -455,13 +453,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
455453
# TODO: implement
456454
pass
457455

458-
def get_exported_program(
459-
self, graph: Optional[str] = EDGE_DIALECT_GRAPH_KEY
460-
) -> ExportedProgram:
456+
def get_exported_program(self, graph: Optional[str] = None) -> ExportedProgram:
461457
"""
462458
Access helper for ETRecord, defaults to returning Edge Dialect Program
463459
464460
Args:
465-
graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
461+
graph: Name of the graph to access. If None, returns the Edge Dialect Program.
466462
"""
463+
if graph is None:
464+
return self._etrecord.edge_dialect_program
467465
return self._etrecord.graph_map.get(graph)

sdk/etdb/tests/inspector_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ def test_inspector_get_exported_program(self):
197197
EventBlock, "_gen_from_etdump"
198198
), patch.object(
199199
inspector, "gen_graphs_from_etrecord"
200+
), patch.object(
201+
inspector, "create_debug_handle_to_op_node_mapping"
200202
):
201203
# Call the constructor of Inspector
202204
inspector_instance = Inspector(
@@ -209,10 +211,10 @@ def test_inspector_get_exported_program(self):
209211
with tempfile.TemporaryDirectory() as tmpdirname:
210212
generate_etrecord(
211213
tmpdirname + "/etrecord.bin",
214+
edge_output,
212215
et_output,
213216
{
214217
"aten_dialect_output": captured_output,
215-
"edge_dialect_output": edge_output,
216218
},
217219
)
218220

sdk/etdb/tests/inspector_utils_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from executorch.sdk.etdb._inspector_utils import (
1818
create_debug_handle_to_op_node_mapping,
19+
EDGE_DIALECT_GRAPH_KEY,
1920
gen_graphs_from_etrecord,
2021
)
2122
from executorch.sdk.etrecord import generate_etrecord, parse_etrecord
@@ -29,10 +30,10 @@ def test_gen_graphs_from_etrecord(self):
2930
with tempfile.TemporaryDirectory() as tmpdirname:
3031
generate_etrecord(
3132
tmpdirname + "/etrecord.bin",
33+
edge_output,
3234
et_output,
3335
{
3436
"aten_dialect_output": captured_output,
35-
"edge_dialect_output": edge_output,
3637
},
3738
)
3839

@@ -42,17 +43,15 @@ def test_gen_graphs_from_etrecord(self):
4243

4344
self.assertTrue("aten_dialect_output/forward" in graphs)
4445
self.assertTrue("et_dialect_graph_module/forward" in graphs)
45-
self.assertTrue("edge_dialect_output/forward" in graphs)
46+
self.assertTrue(EDGE_DIALECT_GRAPH_KEY in graphs)
4647

4748
self.assertTrue(
4849
isinstance(graphs["aten_dialect_output/forward"], FXOperatorGraph)
4950
)
5051
self.assertTrue(
5152
isinstance(graphs["et_dialect_graph_module/forward"], FXOperatorGraph)
5253
)
53-
self.assertTrue(
54-
isinstance(graphs["edge_dialect_output/forward"], FXOperatorGraph)
55-
)
54+
self.assertTrue(isinstance(graphs[EDGE_DIALECT_GRAPH_KEY], FXOperatorGraph))
5655

5756
def test_create_debug_handle_to_op_node_mapping(self):
5857
debug_handle_to_op_node_map = {}

sdk/etrecord/_etrecord.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
class ETRecordReservedFileNames(str, Enum):
2525
ETRECORD_IDENTIFIER = "ETRECORD_V0"
2626
PROGRAM_BUFFER = "program_buffer"
27+
EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
2728
ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
2829
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
2930
DELEGATE_MAP_NAME = "delegate_map"
3031

3132

3233
@dataclass
3334
class ETRecord:
35+
edge_dialect_program: Optional[ExportedProgram] = None
3436
graph_map: Optional[Dict[str, ExportedProgram]] = None
3537
program_buffer: Optional[bytes] = None
3638
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
@@ -76,7 +78,7 @@ def _handle_export_module(
7678
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")
7779

7880

79-
def _handle_program(
81+
def _handle_executorch_program(
8082
etrecord_zip: ZipFile,
8183
program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
8284
) -> None:
@@ -111,9 +113,25 @@ def _handle_program(
111113
)
112114

113115

116+
def _handle_edge_dialect_exported_program(
117+
etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
118+
) -> None:
119+
serialized_ep, serialized_state_dict = serialize(edge_dialect_exported_program)
120+
121+
etrecord_zip.writestr(
122+
ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM,
123+
serialized_ep,
124+
)
125+
etrecord_zip.writestr(
126+
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_state_dict",
127+
serialized_state_dict,
128+
)
129+
130+
114131
def generate_etrecord(
115132
etrecord_path: str,
116-
program: Optional[Union[ExecutorchProgram, MultiMethodExecutorchProgram]] = None,
133+
edge_dialect_program: ExirExportedProgram,
134+
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
117135
export_modules: Optional[
118136
Dict[
119137
str,
@@ -132,7 +150,8 @@ def generate_etrecord(
132150
133151
Args:
134152
etrecord_path: Path to where the ETRecord file will be saved to.
135-
program: ExecutorchProgram or MultiMethodExecutorchProgram for this model returned by the
153+
edge_dialect_program: ExirExportedProgram for this model returned by the call to to_edge()
154+
executorch_program: ExecutorchProgram or MultiMethodExecutorchProgram for this model returned by the
136155
call to to_executorch()
137156
export_modules: Dictionary of graph modules with the key being the user provided name and the
138157
value is the corresponding exported module. The exported graph modules can be either the
@@ -158,18 +177,22 @@ def generate_etrecord(
158177
)
159178
_handle_export_module(etrecord_zip, export_module, module_name)
160179

161-
if program is not None:
162-
_handle_program(etrecord_zip, program)
180+
_handle_executorch_program(etrecord_zip, executorch_program)
163181

164-
etrecord_zip.writestr(
165-
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
166-
json.dumps(program.debug_handle_map),
167-
)
182+
_handle_edge_dialect_exported_program(
183+
etrecord_zip,
184+
edge_dialect_program.exported_program,
185+
)
168186

169-
etrecord_zip.writestr(
170-
ETRecordReservedFileNames.DELEGATE_MAP_NAME,
171-
json.dumps(program.delegate_map),
172-
)
187+
etrecord_zip.writestr(
188+
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
189+
json.dumps(executorch_program.debug_handle_map),
190+
)
191+
192+
etrecord_zip.writestr(
193+
ETRecordReservedFileNames.DELEGATE_MAP_NAME,
194+
json.dumps(executorch_program.delegate_map),
195+
)
173196

174197

175198
def parse_etrecord(etrecord_path: str) -> ETRecord:
@@ -202,6 +225,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:
202225
debug_handle_map = None
203226
delegate_map = None
204227
program_buffer = None
228+
edge_dialect_program = None
205229

206230
serialized_exported_program_files = set()
207231
serialized_state_dict_files = set()
@@ -218,6 +242,13 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:
218242
program_buffer = etrecord_zip.read(ETRecordReservedFileNames.PROGRAM_BUFFER)
219243
elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
220244
continue
245+
elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
246+
edge_dialect_program = deserialize(
247+
etrecord_zip.read(
248+
ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
249+
),
250+
etrecord_zip.read(f"{entry}_state_dict"),
251+
)
221252
else:
222253
if entry.endswith("state_dict"):
223254
serialized_state_dict_files.add(entry)
@@ -235,6 +266,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:
235266
)
236267

237268
return ETRecord(
269+
edge_dialect_program=edge_dialect_program,
238270
graph_map=graph_map,
239271
program_buffer=program_buffer,
240272
_debug_handle_map=debug_handle_map,

sdk/etrecord/tests/etrecord_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def test_etrecord_generation(self):
5151
with tempfile.TemporaryDirectory() as tmpdirname:
5252
generate_etrecord(
5353
tmpdirname + "/etrecord.bin",
54+
edge_output,
5455
et_output,
5556
{
5657
"aten_dialect_output": captured_output,
57-
"edge_dialect_output": edge_output,
5858
},
5959
)
6060

@@ -64,7 +64,7 @@ def test_etrecord_generation(self):
6464
captured_output.exported_program.graph_module,
6565
)
6666
self.check_graph_closeness(
67-
etrecord.graph_map["edge_dialect_output/forward"],
67+
etrecord.edge_dialect_program,
6868
edge_output.exported_program.graph_module,
6969
)
7070
self.check_graph_closeness(
@@ -82,7 +82,10 @@ def test_etrecord_invalid_input(self):
8282
with tempfile.TemporaryDirectory() as tmpdirname:
8383
with self.assertRaises(RuntimeError):
8484
generate_etrecord(
85-
tmpdirname + "/etrecord.bin", {"fail_test_case": et_output}
85+
tmpdirname + "/etrecord.bin",
86+
edge_output,
87+
et_output,
88+
{"fail_test_case": et_output},
8689
)
8790

8891
def test_etrecord_reserved_name(self):
@@ -92,5 +95,7 @@ def test_etrecord_reserved_name(self):
9295
with self.assertRaises(RuntimeError):
9396
generate_etrecord(
9497
tmpdirname + "/etrecord.bin",
98+
edge_output,
99+
et_output,
95100
{reserved_name: captured_output.exported_program.graph_module},
96101
)

0 commit comments

Comments
 (0)