|
17 | 17 |
|
18 | 18 | from unittest.mock import patch
|
19 | 19 |
|
| 20 | +import pandas as pd |
| 21 | + |
20 | 22 | import torch
|
21 | 23 | import torch.fx
|
22 | 24 |
|
@@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self):
|
578 | 580 | self.assertIn((key,), runtime_outputs)
|
579 | 581 | self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
|
580 | 582 |
|
| 583 | + def test_calculate_numeric_gap(self): |
| 584 | + # Create a context manager to patch functions called by Inspector.__init__ |
| 585 | + with patch.object( |
| 586 | + _inspector, "parse_etrecord", return_value=None |
| 587 | + ), patch.object( |
| 588 | + _inspector, "gen_etdump_object", return_value=None |
| 589 | + ), patch.object( |
| 590 | + EventBlock, "_gen_from_etdump" |
| 591 | + ), patch.object( |
| 592 | + _inspector, "gen_graphs_from_etrecord" |
| 593 | + ): |
| 594 | + # Call the constructor of Inspector |
| 595 | + inspector_instance = Inspector( |
| 596 | + etdump_path=ETDUMP_PATH, |
| 597 | + etrecord=ETRECORD_PATH, |
| 598 | + ) |
| 599 | + |
| 600 | + aot_intermediate_outputs = { |
| 601 | + (0,): torch.tensor([1.0, 2.0, 3.0]), |
| 602 | + (1,): torch.tensor([4.0, 5.0, 6.0]), |
| 603 | + } |
| 604 | + |
| 605 | + runtime_intermediate_outputs = { |
| 606 | + (0,): torch.tensor([2.0, 1.0, 4.0]), |
| 607 | + (1,): torch.tensor([3.0, 6.0, 5.0]), |
| 608 | + } |
| 609 | + |
| 610 | + inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs |
| 611 | + inspector_instance._get_runtime_intermediate_outputs = ( |
| 612 | + lambda: runtime_intermediate_outputs |
| 613 | + ) |
| 614 | + |
| 615 | + df = inspector_instance.calculate_numeric_gap(distance="L1") |
| 616 | + self.assertIsInstance(df, pd.DataFrame) |
| 617 | + self.assertEqual(len(df), 2) |
| 618 | + cols = set(df.columns) |
| 619 | + expected_cols = { |
| 620 | + "aot_debug_handle", |
| 621 | + "aot_intermediate_output", |
| 622 | + "runtime_debug_handle", |
| 623 | + "runtime_intermediate_output", |
| 624 | + "gap", |
| 625 | + } |
| 626 | + self.assertEqual(cols, expected_cols) |
| 627 | + founded_aot_debug_handle = set(df["aot_debug_handle"]) |
| 628 | + self.assertEqual( |
| 629 | + founded_aot_debug_handle, set(aot_intermediate_outputs.keys()) |
| 630 | + ) |
| 631 | + for _, row in df.iterrows(): |
| 632 | + aot_debuh_handle = row["aot_debug_handle"] |
| 633 | + # aot_intermediate_output should equal aot_intermediate_outputs[h] |
| 634 | + self.assertTrue( |
| 635 | + torch.allclose( |
| 636 | + row["aot_intermediate_output"], |
| 637 | + aot_intermediate_outputs[aot_debuh_handle], |
| 638 | + ) |
| 639 | + ) |
| 640 | + # runtime_debug_hanlde equals aot_debug_handle at this case |
| 641 | + self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle) |
| 642 | + # runtime_intermediate_output should equal runtime_intermediate_outputs[h] |
| 643 | + self.assertTrue( |
| 644 | + torch.allclose( |
| 645 | + row["runtime_intermediate_output"], |
| 646 | + runtime_intermediate_outputs[aot_debuh_handle], |
| 647 | + ) |
| 648 | + ) |
| 649 | + # gap should equal 3.0 |
| 650 | + self.assertEqual(row["gap"], 3.0) |
| 651 | + |
581 | 652 | def _gen_random_float_list(self) -> List[float]:
|
582 | 653 | return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
|
583 | 654 |
|
|
0 commit comments