Skip to content

Commit 3c7caea

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Optionally disable debug handle validateion
Summary: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This flag allows one to override such behavior and make best effort comparison. Reviewed By: Gasoonjia Differential Revision: D81784685
1 parent 5d63bad commit 3c7caea

File tree

3 files changed

+162
-16
lines changed

3 files changed

+162
-16
lines changed

devtools/inspector/_inspector.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,7 @@ def _consume_etrecord(self) -> None:
11691169

11701170
def _get_aot_intermediate_outputs_and_op_names(
11711171
self,
1172+
disable_debug_handle_valdiation: bool = False,
11721173
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11731174
"""
11741175
Capture intermediate outputs only if _representative_inputs are provided
@@ -1184,6 +1185,7 @@ def _get_aot_intermediate_outputs_and_op_names(
11841185
self._etrecord.exported_program,
11851186
self._etrecord.export_graph_id,
11861187
self._etrecord.edge_dialect_program,
1188+
disable_debug_handle_valdiation,
11871189
):
11881190
export_program = self._etrecord.exported_program
11891191
else:
@@ -1404,7 +1406,7 @@ def get_exported_program(
14041406
else self._etrecord.graph_map.get(graph)
14051407
)
14061408

1407-
def calculate_numeric_gap(self, distance: str = "MSE"):
1409+
def calculate_numeric_gap(self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False):
14081410
"""
14091411
Compares logged intermediate outputs from the exported graph (in ETRecord)
14101412
with runtime outputs (in ETDump) using a user-specific numerical comparator.
@@ -1416,12 +1418,17 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14161418
14171419
Args:
14181420
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
1421+
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc.,
1422+
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection
1423+
between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding
1424+
node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This
1425+
flag allows one to override such behavior and make best effort comparison.
14191426
14201427
Returns:
14211428
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
14221429
"""
14231430
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
1424-
self._get_aot_intermediate_outputs_and_op_names()
1431+
self._get_aot_intermediate_outputs_and_op_names(disable_debug_handle_valdiation)
14251432
)
14261433
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
14271434
raise ValueError(
@@ -1451,6 +1458,12 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14511458
) in mapping.items():
14521459
if aot_intermediate_output is None or runtime_intermediate_output is None:
14531460
continue
1461+
# If aot outputs length is > 1 then comparison fails since we dont really have
1462+
# any instances where runtime intermediate output is a tuple or list
1463+
# This does not happen when edge dialect program is reference for comparison
1464+
# but happens in aten graph where ops like unbind remain undecomposed
1465+
if isinstance(aot_intermediate_output, Sequence) and len(aot_intermediate_output) > 1:
1466+
continue
14541467
rows.append(
14551468
{
14561469
"aot_ops": find_op_names(

devtools/inspector/_inspector_utils.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
657657
# Combine all AOT debug_handles into a list
658658
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
659659

660-
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
661-
# AOT combined debug_handle and runtime debug_handle do not match.
660+
# Reason we dont check for exact match:
661+
# in some experiments where we want to rewrite the aten graph that was
662+
# lowered, so as to use custom ops like int4_matmul, we lose some nodes
663+
# on the graph and thus lose some debug handles. And we dont find
664+
# exact match within connected components.
665+
if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)):
666+
# AOT combined debug_handle is not a subset of runtime debug_handle.
662667
return (-1,), None
663668

664669
# Pick the last intermediate output
665670
last_int = runtime_debug_handle[negative_index]
666671
key = (last_int,)
672+
if key not in aot_map:
673+
# If the last intermediate output is not in the AOT map, return None
674+
return (-1,), None
667675
return runtime_debug_handle, aot_map[key]
668676

669677

@@ -965,7 +973,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
965973
# Ensure both sequences have the same length
966974
if len(a) != len(b):
967975
raise ValueError(
968-
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
976+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
969977
)
970978

971979
# Compare each element in the sequences and return the list of results
@@ -990,6 +998,9 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]:
990998
Returns: the identifiers of all its ancestor nodes
991999
"""
9921000

1001+
if FROM_NODE_KEY not in node.meta:
1002+
return None
1003+
9931004
node_source = node.meta[FROM_NODE_KEY]
9941005
node_source = node_source[-1]
9951006
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]
@@ -1056,11 +1067,14 @@ def _find_n_match_node(node: Node) -> None:
10561067
if node.op in ("output", "placeholder"):
10571068
return
10581069
node_id = f"{node.name}.{exported_program_graph_id}"
1059-
parent_node_id = get_parent_node_identifier(node)
1070+
parent_node_ids = get_ancestor_node_identifiers(node)
10601071
if node_id in ancestors_node_id_to_debug_handle:
10611072
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
1062-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1063-
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
1073+
elif parent_node_ids:
1074+
for parent_node_id in parent_node_ids:
1075+
if parent_node_id in ancestors_node_id_to_debug_handle:
1076+
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
1077+
break
10641078

10651079
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
10661080
return matched_debug_handles
@@ -1094,15 +1108,17 @@ def _equip_debug_handle(node: Node) -> None:
10941108
if node.op in ("output", "placeholder"):
10951109
return
10961110
node_id = f"{node.name}.{exported_program_graph_id}"
1097-
parent_node_id = get_parent_node_identifier(node)
1111+
parent_node_ids = get_ancestor_node_identifiers(node)
1112+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
10981113
if node_id in ancestors_node_id_to_debug_handle:
10991114
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
1100-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1101-
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1102-
parent_node_id
1103-
]
1104-
else:
1105-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1115+
elif parent_node_ids:
1116+
for parent_node_id in parent_node_ids:
1117+
if parent_node_id in ancestors_node_id_to_debug_handle:
1118+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1119+
parent_node_id
1120+
]
1121+
break
11061122

11071123
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
11081124

@@ -1111,6 +1127,7 @@ def propagate_back_debug_handle(
11111127
exported_program: ExportedProgram,
11121128
exported_program_graph_id: int,
11131129
edge_dialect_program: ExportedProgram,
1130+
disable_debug_handle_valdiation: bool = False,
11141131
) -> bool:
11151132
"""
11161133
Propagate debug handle from edge dialect program back to the exported program while maintain the correctness
@@ -1124,6 +1141,10 @@ def propagate_back_debug_handle(
11241141
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
11251142
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
11261143
1144+
disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch.
1145+
This can happen when we are comparing against aten graph in which case not all debug handles are matched
1146+
in aten graph. Example of this is when symbolic shape nodes are re-exported.
1147+
11271148
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
11281149
"""
11291150
# 1. Extract mapping from ancestor node identifiers to debug handles
@@ -1137,7 +1158,7 @@ def propagate_back_debug_handle(
11371158
)
11381159

11391160
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1140-
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
1161+
if not disable_debug_handle_valdiation and not _verify_graph_match(edge_dialect_program, matched_debug_handles):
11411162
return False
11421163

11431164
# 4. Apply debug handles to the exported program

devtools/inspector/tests/inspector_test.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,118 @@ def _gen_random_runtime_output(
838838
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
839839
return [torch.randn(RAW_DATA_SIZE)]
840840

841+
def test_disable_debug_handle_validation_with_symbolic_shapes(self):
842+
"""
843+
Test that demonstrates the issue with symbolic shape related nodes losing from_node info
844+
during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter
845+
in propagate_back_debug_handle allows validation to be bypassed.
846+
"""
847+
from executorch.devtools.inspector._inspector_utils import propagate_back_debug_handle
848+
849+
class SymbolicShapeModel(torch.nn.Module):
850+
"""Model that will have symbolic shape related operations after export."""
851+
852+
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
853+
# This will create symbolic shape nodes during dynamic export
854+
batch_size = x.shape[0]
855+
x = x + torch.rand((batch_size, 1))
856+
# Masking operation that creates gt/lt nodes
857+
valid_mask = mask > 0.5
858+
x = torch.where(valid_mask, x, torch.zeros_like(x))
859+
return x
860+
861+
# Create model and dynamic inputs
862+
model = SymbolicShapeModel()
863+
batch_size = 2
864+
seq_len = 4
865+
x = torch.randn(batch_size, seq_len)
866+
mask = torch.rand(batch_size, seq_len)
867+
example_inputs = (x, mask)
868+
869+
# Export with dynamic shapes to create symbolic shape related nodes
870+
dynamic_shapes = {
871+
"x": {0: torch.export.Dim("batch_size", min=1, max=10)},
872+
"mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
873+
}
874+
875+
exported_program = torch.export.export(
876+
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
877+
)
878+
879+
"""
880+
In this case origina aten graph has sym_size_int_2 node but when we look at
881+
nodes metadata in edge_program_manager, its sym_size node's from_node says
882+
sym_size_int_3 which is not in the original aten graph.
883+
"""
884+
# Create edge program - this is where from_node info can be lost for symbolic shape nodes
885+
edge_program_manager: EdgeProgramManager = to_edge(exported_program)
886+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
887+
et_program_manager: ExecutorchProgramManager = edge_program_manager.to_executorch()
888+
889+
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
890+
etrecord_path = tmp_file.name
891+
892+
# Generate ETRecord with the exported program (aten graph)
893+
generate_etrecord(
894+
etrecord_path,
895+
edge_program_manager_copy,
896+
et_program_manager,
897+
exported_program=exported_program,
898+
)
899+
900+
# Create Inspector and get etrecord
901+
with patch.object(
902+
_inspector, "gen_etdump_object", return_value=None
903+
), patch.object(
904+
EventBlock, "_gen_from_etdump"
905+
):
906+
inspector_instance = Inspector(
907+
etdump_path=ETDUMP_PATH,
908+
etrecord=etrecord_path,
909+
)
910+
911+
# Extract the necessary values from the inspector's etrecord
912+
exported_program_from_etrecord = inspector_instance._etrecord.exported_program
913+
export_graph_id = inspector_instance._etrecord.export_graph_id
914+
edge_dialect_program = inspector_instance._etrecord.edge_dialect_program
915+
916+
# Ensure we have all the necessary components
917+
self.assertIsNotNone(exported_program_from_etrecord)
918+
self.assertIsNotNone(export_graph_id)
919+
self.assertIsNotNone(edge_dialect_program)
920+
921+
# Test propagate_back_debug_handle with validation enabled (should fail or return False)
922+
# This demonstrates the issue with symbolic shape nodes losing from_node info
923+
validation_enabled_result = propagate_back_debug_handle(
924+
exported_program_from_etrecord,
925+
export_graph_id,
926+
edge_dialect_program,
927+
disable_debug_handle_valdiation=False
928+
)
929+
930+
# With validation enabled, it should return False when from_node info is lost
931+
self.assertFalse(
932+
validation_enabled_result,
933+
"propagate_back_debug_handle should return False when validation is enabled "
934+
"and symbolic shape nodes lose from_node info"
935+
)
936+
937+
# Test propagate_back_debug_handle with validation disabled (should succeed)
938+
# This shows how the disable_debug_handle_valdiation flag allows the function to work
939+
validation_disabled_result = propagate_back_debug_handle(
940+
exported_program_from_etrecord,
941+
export_graph_id,
942+
edge_dialect_program,
943+
disable_debug_handle_valdiation=True
944+
)
945+
946+
# With validation disabled, it should return True even when from_node info is lost
947+
self.assertTrue(
948+
validation_disabled_result,
949+
"propagate_back_debug_handle should return True when validation is disabled, "
950+
"allowing best effort comparison even when from_node info is lost"
951+
)
952+
841953
def _gen_random_events(self) -> List[Event]:
842954
events = []
843955
for i in range(2):

0 commit comments

Comments
 (0)