|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
| 9 | +import copy |
| 10 | +import re |
9 | 11 | import reprlib |
10 | 12 | from dataclasses import fields |
11 | 13 | from enum import IntEnum |
12 | 14 | from typing import Any, List |
13 | 15 |
|
14 | 16 | import torch |
15 | | - |
16 | 17 | from executorch.exir.error import ExportError, ExportErrorType, InternalError |
| 18 | + |
17 | 19 | from executorch.exir.schema import ( |
18 | 20 | Bool, |
19 | 21 | BoolList, |
20 | 22 | DelegateCall, |
21 | 23 | Double, |
22 | 24 | DoubleList, |
23 | 25 | EValue, |
| 26 | + Frame, |
24 | 27 | FrameList, |
25 | 28 | FreeCall, |
26 | 29 | Int, |
@@ -302,3 +305,109 @@ def pretty_print_stacktraces(obj: FrameList) -> str: |
302 | 305 | pretty += f"{frame.context} \n" |
303 | 306 | pretty += "\n" |
304 | 307 | return pretty |
| 308 | + |
| 309 | + |
def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
    """
    Render an fx.Graph as text with a ``-->`` cursor in front of one node.

    e.g:
    # graph():
    #     %x : [#users=1] = placeholder[target=x]
    #     %param : [#users=1] = get_attr[target=param]
    #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    #     return clamp

    This is mostly used for error reporting.

    Args:
        graph: The graph to render. It is only read, never modified.
        finding_node: The node to mark. If it is not one of ``graph.nodes``,
            the graph is rendered without any cursor.

    Returns:
        A multi-line string starting with ``graph():`` followed by one line
        per node, as produced by ``node.format_node()``. The marked node is
        prefixed with ``--> `` and every other node with four spaces.
    """
    # This is heavily based on the __str__ method of fx.Graph. The graph is
    # formatted directly: nothing here mutates it, so no deep copy is needed.
    rendered = ["graph():"]
    for node in graph.nodes:
        node_str = node.format_node()
        if not node_str:
            # Nodes that format to nothing are left out of the listing.
            continue
        prefix = "--> " if node == finding_node else "    "
        rendered.append(prefix + node_str)
    return "\n".join(rendered)
| 345 | + |
| 346 | + |
def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
    """
    Creates a frame list from a stacktrace string.

    Every line containing ``File "<filename>", line <lineno>, in <name>`` starts
    a frame. The frame's context is the line that directly follows it, stripped
    of surrounding whitespace. If the following line is missing or is itself a
    frame header, the context is an empty string.

    Pairing each header with the line right after it (instead of computing a
    fixed ``2 * i + 1`` offset) keeps frames and contexts aligned when the
    string has a leading header line, frames without a source line, or extra
    marker lines between frames.

    Args:
        stacktrace: The stack trace text, one entry per line.

    Returns:
        A FrameList with one Frame per frame header, in order of appearance.
    """
    frame_header = re.compile(r'File "(.*?)", line (\d+), in (.*)')
    # Split once up front; the text is the same for every frame.
    lines = stacktrace.split("\n")

    frames: List[Frame] = []
    for idx, line in enumerate(lines):
        header = frame_header.search(line)
        if header is None:
            continue
        following = lines[idx + 1] if idx + 1 < len(lines) else ""
        # A frame whose source line is unavailable is followed directly by the
        # next frame header; do not mistake that header for this frame's context.
        context = "" if frame_header.search(following) else following.strip()
        frames.append(
            Frame(
                filename=header.group(1),
                lineno=int(header.group(2)),
                name=header.group(3),
                context=context,
            )
        )
    return FrameList(frames)
| 361 | + |
| 362 | + |
def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
    """
    Inspect a node by highlighting the node in the graph as well as the stacktrace.

    The message always contains the graph listing with a cursor on the node.
    A "The node spec:" section is appended when ``node.meta`` has a ``"spec"``
    entry, and a "The node stacktrace:" section is appended when ``node.meta``
    has a ``"stack_trace"`` entry.

    Args:
        graph: The graph containing the node
        node: The node to be inspected

    Return: A string. An example output is:

    Here is the node in the graph module:
    graph():
        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
        %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
        %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
        %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
        %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {})
        %_param_constant4 : [num_users=1] = get_attr[target=_param_constant4]
        %_param_constant5 : [num_users=1] = get_attr[target=_param_constant5]
        %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
        return [aten_gelu_default]
    This node _param_constant0 has metadata of:
    The node stacktrace:
    Traceback (most recent call last):
        File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
    return self.test_model(x)
        File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
        File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
    a = self.conv1(x)

    """
    graph_str_with_cursor = add_cursor_to_graph(graph, node)
    error_msg = (
        "Here is the node in the graph module:\n"
        f"{graph_str_with_cursor}\n"
        f"This node {node} has metadata of:\n"
    )
    # Node spec error message.
    # node.meta is a dict, so membership (not hasattr) is the correct check;
    # hasattr(node.meta, "spec") looks for an attribute and is always False.
    if "spec" in node.meta:
        error_msg += f"The node spec:\n{node.meta['spec']}\n"

    # Stacktrace error message
    if "stack_trace" in node.meta:
        framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
    return error_msg
0 commit comments