Skip to content

Commit 8accb86

Browse files
cccclaifacebook-github-bot
authored andcommitted
helper function to print node info (#512)
Summary: Pull Request resolved: #512 Add a helper function to print the node info given the graph and a node ``` print(inspect_node(graph, node)) _param_constant1 error_msg: Here is the node in the graph module: graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0] --> %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1] %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2] %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3] %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {}) %_param_constant4 : [num_users=1] = get_attr[target=_param_constant4] %_param_constant5 : [num_users=1] = get_attr[target=_param_constant5] %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {}) return [aten_gelu_default] This node _param_constant1 has metadata of: The node stacktrace: Traceback (most recent call last): File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward return self.test_model(x) File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl return forward_call(*args, **kwargs) File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward a = self.conv1(x) ``` Reviewed By: tarun292 Differential Revision: D49715284 fbshipit-source-id: 0e901c7c5b8be074261c5f4fc4f8dc72d7c23db5
1 parent 645594e commit 8accb86

File tree

7 files changed

+181
-85
lines changed

7 files changed

+181
-85
lines changed

exir/common.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -92,43 +92,6 @@ def format_schema_name(schema: torch._C.FunctionSchema) -> str:
9292
return schema.name
9393

9494

95-
def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
96-
"""
97-
Insert a cursor at the node location in the fx.Graph.
98-
e.g:
99-
# graph():
100-
# %x : [#users=1] = placeholder[target=x]
101-
# %param : [#users=1] = get_attr[target=param]
102-
# %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
103-
# --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
104-
# %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
105-
# return clamp
106-
107-
This is mostly used for error reporting
108-
"""
109-
110-
new_graph = copy.deepcopy(graph)
111-
112-
found_at = -1
113-
for ix, node in enumerate(graph.nodes):
114-
if node == finding_node:
115-
found_at = ix
116-
117-
# This is heavily based on __str__ method of fx.Graph
118-
def _format_graph(graph: torch.fx.Graph, offending_node_idx: int) -> str:
119-
s = "graph():"
120-
for ix, node in enumerate(graph.nodes):
121-
node_str = node.format_node()
122-
if node_str:
123-
if ix != offending_node_idx:
124-
s += "\n " + node_str
125-
else:
126-
s += "\n--> " + node_str
127-
return s
128-
129-
return _format_graph(new_graph, found_at)
130-
131-
13295
@contextmanager
13396
def override_logger(
13497
newLevel: int = logging.DEBUG,

exir/emit/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ python_library(
1919
deps = [
2020
"fbsource//third-party/pypi/typing-extensions:typing-extensions",
2121
"//caffe2:torch",
22-
"//executorch/exir:common",
2322
"//executorch/exir:delegate",
2423
"//executorch/exir:error",
2524
"//executorch/exir:memory",

exir/emit/_emitter.py

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
# pyre-strict
3131
import ctypes
3232
import operator
33-
import re
3433
import typing
3534
from dataclasses import dataclass, field
3635
from typing import Callable, cast, Dict, List, Mapping, Optional, Tuple, Union
@@ -39,14 +38,13 @@
3938
import executorch.extension.pytree as ex_pytree
4039
import torch
4140
import torch.fx
42-
from executorch.exir.common import add_cursor_to_graph
4341
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
4442
from executorch.exir.dialects.backend._ops import BackendOpOverload
4543
from executorch.exir.dialects.edge._ops import EdgeOpOverload
4644
from executorch.exir.error import ExportError, ExportErrorType, InternalError
4745
from executorch.exir.operator.convert import is_out_variant
4846
from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
49-
from executorch.exir.print_program import pretty_print_stacktraces
47+
from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
5048
from executorch.exir.schema import (
5149
BackendDelegate,
5250
BackendDelegateDataReference,
@@ -62,8 +60,6 @@
6260
DoubleList,
6361
EValue,
6462
ExecutionPlan,
65-
Frame,
66-
FrameList,
6763
FreeCall,
6864
Instruction,
6965
Int,
@@ -238,44 +234,12 @@ def __init__(
238234
int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
239235
] = {}
240236

241-
def _stacktrace_to_framelist(self, stacktrace: str) -> FrameList:
242-
"""Creates a frame list from a stacktrace string."""
243-
pattern = r'File "(.*?)", line (\d+), in (.*?)\n'
244-
matches = re.findall(pattern, stacktrace)
245-
mapped_frame_list = [
246-
Frame(
247-
filename=match[0],
248-
lineno=int(match[1]),
249-
name=match[2],
250-
context=stacktrace.split("\n")[i * 2 + 1].strip(),
251-
)
252-
for i, match in enumerate(matches)
253-
]
254-
return FrameList(mapped_frame_list)
255-
256237
def _emit_node_specific_error(self, node: torch.fx.Node, err_msg: str) -> str:
257238
"""Returns 'err_msg' with node specific information attached."""
258-
graph_str_with_cursor = add_cursor_to_graph(self.graph_module.graph, node)
259-
260-
error_msg = (
261-
f"Failed with error: {str(err_msg)}\n"
262-
f"Here is the failing node in the graph module:\n"
263-
f"{graph_str_with_cursor}\n"
264-
f"This node {self.node} has metadata of:\n"
239+
err_msg = f"Failed with error: {str(err_msg)}\n" + inspect_node(
240+
self.graph_module.graph, node
265241
)
266-
267-
# Node spec error message
268-
if hasattr(self.node.meta, "spec"):
269-
error_msg += f"The node spec:\n{self.node.meta['spec']}\n"
270-
271-
# Stacktrace error message
272-
if hasattr(self.node.meta, "stack_trace"):
273-
framelist = self._stacktrace_to_framelist(self.node.meta["stack_trace"])
274-
error_msg += (
275-
f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
276-
)
277-
278-
return error_msg
242+
return err_msg
279243

280244
def _internal_assert_emitter(
281245
self, pred: bool, node: torch.fx.Node, assert_msg: str
@@ -1133,7 +1097,7 @@ def _get_empty_tensor_evalue() -> EValue:
11331097
stack_trace = self.node.meta["stack_trace"]
11341098
chain_stacktrace = self.chain.stacktrace or []
11351099

1136-
chain_stacktrace.append(self._stacktrace_to_framelist(stack_trace))
1100+
chain_stacktrace.append(_stacktrace_to_framelist(stack_trace))
11371101
self._internal_assert_emitter(
11381102
len(chain_stacktrace) == len(self.chain.instructions),
11391103
self.node,

exir/print_program.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,24 @@
66

77
# pyre-strict
88

9+
import copy
10+
import re
911
import reprlib
1012
from dataclasses import fields
1113
from enum import IntEnum
1214
from typing import Any, List
1315

1416
import torch
15-
1617
from executorch.exir.error import ExportError, ExportErrorType, InternalError
18+
1719
from executorch.exir.schema import (
1820
Bool,
1921
BoolList,
2022
DelegateCall,
2123
Double,
2224
DoubleList,
2325
EValue,
26+
Frame,
2427
FrameList,
2528
FreeCall,
2629
Int,
@@ -302,3 +305,109 @@ def pretty_print_stacktraces(obj: FrameList) -> str:
302305
pretty += f"{frame.context} \n"
303306
pretty += "\n"
304307
return pretty
308+
309+
310+
def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
311+
"""
312+
Insert a cursor at the node location in the fx.Graph.
313+
e.g:
314+
# graph():
315+
# %x : [#users=1] = placeholder[target=x]
316+
# %param : [#users=1] = get_attr[target=param]
317+
# %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
318+
# --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
319+
# %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
320+
# return clamp
321+
322+
This is mostly used for error reporting
323+
"""
324+
325+
new_graph = copy.deepcopy(graph)
326+
327+
found_at = -1
328+
for ix, node in enumerate(graph.nodes):
329+
if node == finding_node:
330+
found_at = ix
331+
332+
# This is heavily based on __str__ method of fx.Graph
333+
def _format_graph(graph: torch.fx.Graph, offending_node_idx: int) -> str:
334+
s = "graph():"
335+
for ix, node in enumerate(graph.nodes):
336+
node_str = node.format_node()
337+
if node_str:
338+
if ix != offending_node_idx:
339+
s += "\n " + node_str
340+
else:
341+
s += "\n--> " + node_str
342+
return s
343+
344+
return _format_graph(new_graph, found_at)
345+
346+
347+
def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
348+
"""Creates a frame list from a stacktrace string."""
349+
pattern = r'File "(.*?)", line (\d+), in (.*?)\n'
350+
matches = re.findall(pattern, stacktrace)
351+
mapped_frame_list = [
352+
Frame(
353+
filename=match[0],
354+
lineno=int(match[1]),
355+
name=match[2],
356+
context=stacktrace.split("\n")[i * 2 + 1].strip(),
357+
)
358+
for i, match in enumerate(matches)
359+
]
360+
return FrameList(mapped_frame_list)
361+
362+
363+
def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
364+
"""
365+
Inspect a node by highlighting the node in the graph as well as the stacktrace.
366+
367+
Args:
368+
graph: The graph containing the node
369+
node: The node to be inspected
370+
371+
Return: A string. An example output is:
372+
373+
_param_constant0 error_msg: Here is the failing node in the graph module:
374+
graph():
375+
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
376+
--> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
377+
%_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
378+
%aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
379+
%_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
380+
%_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
381+
%aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
382+
%aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {})
383+
%_param_constant4 : [num_users=1] = get_attr[target=_param_constant4]
384+
%_param_constant5 : [num_users=1] = get_attr[target=_param_constant5]
385+
%aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
386+
%aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
387+
return [aten_gelu_default]
388+
This node _param_constant0 has metadata of:
389+
The node stacktrace:
390+
Traceback (most recent call last):
391+
File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
392+
return self.test_model(x)
393+
File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl
394+
return forward_call(*args, **kwargs)
395+
File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
396+
a = self.conv1(x)
397+
398+
"""
399+
graph_str_with_cursor = add_cursor_to_graph(graph, node)
400+
error_msg = (
401+
f"Here is the node in the graph module:\n"
402+
f"{graph_str_with_cursor}\n"
403+
f"This node {node} has metadata of:\n"
404+
)
405+
# Node spec error message
406+
if hasattr(node.meta, "spec"):
407+
error_msg += f"The node spec:\n{node.meta['spec']}\n"
408+
409+
# Stacktrace error message
410+
if "stack_trace" in node.meta:
411+
framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
412+
error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
413+
return error_msg

exir/tests/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ python_unittest(
265265
deps = [
266266
"//caffe2:torch",
267267
"//executorch/exir:common",
268+
"//executorch/exir:print_program",
268269
],
269270
)
270271

@@ -430,3 +431,15 @@ python_unittest(
430431
"//executorch/exir/passes:lib",
431432
],
432433
)
434+
435+
python_unittest(
436+
name = "print_program",
437+
srcs = [
438+
"test_print_program.py",
439+
],
440+
deps = [
441+
"//caffe2:torch",
442+
"//executorch/exir:lib",
443+
"//executorch/exir:print_program",
444+
],
445+
)

exir/tests/test_common.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212
import torch
1313
import torch.fx
1414

15-
from executorch.exir.common import (
16-
add_cursor_to_graph,
17-
extract_out_arguments,
18-
get_schema_for_operators,
19-
)
15+
from executorch.exir.common import extract_out_arguments, get_schema_for_operators
16+
from executorch.exir.print_program import add_cursor_to_graph
2017

2118

2219
class TestExirCommon(unittest.TestCase):

exir/tests/test_print_program.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch import exir
11+
from executorch.exir.print_program import inspect_node
12+
13+
14+
class TestPrintProgram(unittest.TestCase):
15+
def test_inspect_node(self) -> None:
16+
class TestModel(torch.nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
self.conv1 = torch.nn.Conv2d(32, 32, 1)
20+
self.conv2 = torch.nn.Conv2d(32, 32, 1)
21+
self.conv3 = torch.nn.Conv2d(32, 32, 1)
22+
self.gelu = torch.nn.GELU()
23+
24+
def forward(self, x: torch.Tensor):
25+
a = self.conv1(x)
26+
b = self.conv2(a)
27+
c = self.conv3(a + b)
28+
return self.gelu(c)
29+
30+
class WrapModule(torch.nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
self.test_model = TestModel()
34+
35+
def forward(self, x):
36+
return self.test_model(x)
37+
38+
warp_model = WrapModule()
39+
example_inputs = (torch.rand(1, 32, 16, 16),)
40+
41+
exir_exported_program = exir.capture(warp_model, example_inputs).to_edge()
42+
number_of_stack_trace = 0
43+
for node in exir_exported_program.exported_program.graph.nodes:
44+
node_info = inspect_node(exir_exported_program.exported_program.graph, node)
45+
self.assertRegexpMatches(node_info, r".*-->.*")
46+
if "stack_trace" in node.meta:
47+
self.assertRegexpMatches(
48+
node_info, r".*Traceback \(most recent call last\)\:.*"
49+
)
50+
number_of_stack_trace = number_of_stack_trace + 1
51+
self.assertGreaterEqual(number_of_stack_trace, 1)

0 commit comments

Comments
 (0)