Skip to content

Commit 1492cc6

Browse files
committed
Add import and export options to loss_compare.py
This allows us to check if the loss is consistent across commits/PRs ghstack-source-id: 4715975 Pull-Request: #2063
1 parent cb4cf9d commit 1492cc6

File tree

2 files changed

+151
-9
lines changed

2 files changed

+151
-9
lines changed

.github/workflows/integration_test_8gpu_features.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ jobs:
9393
python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
9494
9595
# Verify the accuracy.
96-
echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity"
96+
echo "Checking FSDP4 v.s. HSDP2FSDP4 accuracy parity"
9797
export baseline_options="--parallelism.data_parallel_replicate_degree=1"
98-
export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2"
99-
python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1
98+
export test_options="--parallelism.data_parallel_replicate_degree=4"
99+
python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --export-result "${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs/loss_compare_result.txt"
100100
101101
# Cleanup the checkpoints so that we don't waste network bandwidth and time.
102102
rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint

scripts/loss_compare.py

Lines changed: 148 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ def validate_arguments(
168168
test_train_file: str,
169169
test_options: str,
170170
steps: int,
171+
assert_equal: bool,
172+
export_result: str | None,
173+
import_result: str | None,
171174
) -> None:
172175
"""Validate command line arguments."""
173176
# Validate commit arguments - if one is ".", both must be "."
@@ -201,6 +204,34 @@ def validate_arguments(
201204
log_print(f"Error: --steps must be a positive integer, got: {steps}")
202205
sys.exit(1)
203206

207+
# Validate export-result requires assert-equal
208+
if export_result and not assert_equal:
209+
log_print("Error: --export-result requires --assert-equal")
210+
log_print(" Export only happens when losses are verified to match")
211+
sys.exit(1)
212+
213+
# Validate import-result requires assert-equal
214+
if import_result and not assert_equal:
215+
log_print("Error: --import-result requires --assert-equal")
216+
log_print(" Import is used to verify all losses match")
217+
sys.exit(1)
218+
219+
# Validate export-result and import-result are mutually exclusive
220+
if export_result and import_result:
221+
log_print(
222+
"Error: --export-result and --import-result cannot be " "used together"
223+
)
224+
log_print(
225+
" Use export to save results or import to compare "
226+
"against saved results"
227+
)
228+
sys.exit(1)
229+
230+
# Validate import file exists
231+
if import_result and not os.path.exists(import_result):
232+
log_print(f"Error: Import file does not exist: {import_result}")
233+
sys.exit(1)
234+
204235

205236
# =============================================================================
206237
# SETUP FUNCTIONS
@@ -321,7 +352,10 @@ def check_git_clean_state() -> None:
321352
for line in result.stdout.strip().split("\n"):
322353
log_print(f" {line}")
323354
log_print("")
324-
log_print("Please commit, stash, or discard your changes before running this script")
355+
log_print(
356+
"Please commit, stash, or discard your changes before "
357+
"running this script"
358+
)
325359
log_print(" - To commit: git add -A && git commit -m 'message'")
326360
log_print(" - To stash: git stash")
327361
log_print(" - To discard: git checkout -- . && git clean -fd")
@@ -431,6 +465,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]:
431465
return losses
432466

433467

468+
def export_losses_to_file(losses: dict[int, float], export_path: str) -> None:
469+
"""Export losses to file and stdout.
470+
471+
Args:
472+
losses: Dictionary mapping step numbers to loss values
473+
export_path: Path to export file
474+
"""
475+
log_print(f"Exporting losses to {export_path}")
476+
477+
# Write to file and collect output for stdout
478+
with open(export_path, "w") as f:
479+
for step in sorted(losses.keys()):
480+
loss = losses[step]
481+
line = f"{step} {loss}"
482+
f.write(line + "\n")
483+
484+
log_print(f"Exported {len(losses)} loss values:")
485+
log_print()
486+
487+
# Output to stdout in same format
488+
for step in sorted(losses.keys()):
489+
loss = losses[step]
490+
print(f"{step} {loss}")
491+
492+
log_print()
493+
log_print(f"Losses saved to: {export_path}")
494+
495+
434496
def extract_loss_data(output_folder: str | None) -> None:
435497
"""Extract loss data from logs."""
436498
if not output_folder:
@@ -554,13 +616,18 @@ def perform_loss_analysis(
554616
generate_summary_statistics(baseline_losses, test_losses, stats_file)
555617

556618

557-
def assert_losses_equal(baseline_log: str, test_log: str) -> None:
558-
"""Assert that losses are equal between baseline and test using
559-
unittest.
619+
def assert_losses_equal(
620+
baseline_log: str, test_log: str, import_result: str | None = None
621+
) -> None:
622+
"""Assert that losses are equal between baseline and test using unittest.
623+
624+
If import_result is provided, also compares baseline with imported losses.
560625
"""
561626
log_print("Asserting losses are equal...")
562627
log_print(f"Baseline log: {baseline_log}")
563628
log_print(f"Test log: {test_log}")
629+
if import_result:
630+
log_print(f"Import file: {import_result}")
564631

565632
# Extract losses from both logs
566633
baseline_losses = extract_losses_from_log(baseline_log)
@@ -577,6 +644,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None:
577644
log_print("Error: No losses found in test log")
578645
sys.exit(1)
579646

647+
# Load imported losses if provided
648+
imported_losses = None
649+
if import_result:
650+
imported_losses = read_losses_from_file(import_result)
651+
log_print(f"Loaded {len(imported_losses)} steps from import file")
652+
if not imported_losses:
653+
log_print("Error: No losses found in import file")
654+
sys.exit(1)
655+
580656
# Create a test case
581657
class LossEqualityTest(unittest.TestCase):
582658
def test_losses_equal(self):
@@ -591,17 +667,41 @@ def test_losses_equal(self):
591667
f"test has {len(test_steps)} steps",
592668
)
593669

670+
# If imported losses exist, check steps match
671+
if imported_losses:
672+
imported_steps = set(imported_losses.keys())
673+
self.assertEqual(
674+
baseline_steps,
675+
imported_steps,
676+
f"Steps mismatch: baseline has {len(baseline_steps)} steps, "
677+
f"imported has {len(imported_steps)} steps",
678+
)
679+
594680
# Check that losses are equal for each step
595681
for step in sorted(baseline_steps):
596682
baseline_loss = baseline_losses[step]
597683
test_loss = test_losses[step]
684+
685+
# Compare baseline vs test
598686
self.assertEqual(
599687
baseline_loss,
600688
test_loss,
601689
f"Loss mismatch at step {step}: "
602690
f"baseline={baseline_loss}, test={test_loss}",
603691
)
604692

693+
# Compare baseline vs imported (if provided)
694+
# No need to compare test vs imported since:
695+
# baseline==test and baseline==imported implies test==imported
696+
if imported_losses:
697+
imported_loss = imported_losses[step]
698+
self.assertEqual(
699+
baseline_loss,
700+
imported_loss,
701+
f"Loss mismatch at step {step}: "
702+
f"baseline={baseline_loss}, imported={imported_loss}",
703+
)
704+
605705
# Run the test
606706
suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest)
607707
runner = unittest.TextTestRunner(verbosity=2)
@@ -611,7 +711,13 @@ def test_losses_equal(self):
611711
log_print("Loss assertion failed!")
612712
sys.exit(1)
613713
else:
614-
log_print("All losses are equal. Assertion passed!")
714+
if import_result:
715+
log_print(
716+
"All losses are equal (baseline, test, and imported). "
717+
"Assertion passed!"
718+
)
719+
else:
720+
log_print("All losses are equal. Assertion passed!")
615721

616722

617723
def cleanup_temp_files(output_folder: str | None) -> None:
@@ -754,6 +860,24 @@ def parse_arguments() -> argparse.Namespace:
754860
"Script exits with error if losses differ."
755861
),
756862
)
863+
parser.add_argument(
864+
"--export-result",
865+
default="",
866+
help=(
867+
"Export losses to specified file path (requires --assert-equal). "
868+
"Exports only when losses match. Format: '{step} {loss}' per line."
869+
),
870+
)
871+
parser.add_argument(
872+
"--import-result",
873+
default="",
874+
help=(
875+
"Import losses from specified file path for comparison "
876+
"(requires --assert-equal). "
877+
"Compares imported losses with both baseline and test "
878+
"(all 3 must match)."
879+
),
880+
)
757881
parser.add_argument(
758882
"--job-dump-folder",
759883
default="outputs",
@@ -785,6 +909,14 @@ def parse_arguments() -> argparse.Namespace:
785909
if not args.output_folder:
786910
args.output_folder = None
787911

912+
# Convert empty export_result to None
913+
if not args.export_result:
914+
args.export_result = None
915+
916+
# Convert empty import_result to None
917+
if not args.import_result:
918+
args.import_result = None
919+
788920
return args
789921

790922

@@ -848,6 +980,9 @@ def main() -> None:
848980
args.test_train_file,
849981
args.test_options,
850982
args.steps,
983+
args.assert_equal,
984+
args.export_result,
985+
args.import_result,
851986
)
852987

853988
# Setup environment
@@ -910,7 +1045,14 @@ def main() -> None:
9101045

9111046
# Assert losses are equal if requested
9121047
if args.assert_equal:
913-
assert_losses_equal(baseline_log, test_log)
1048+
# Pass import_result if provided for 3-way comparison
1049+
assert_losses_equal(baseline_log, test_log, args.import_result)
1050+
1051+
# Export losses if requested (only after assertion passes)
1052+
if args.export_result:
1053+
# Extract baseline losses (they equal test losses since assertion passed)
1054+
baseline_losses = extract_losses_from_log(baseline_log)
1055+
export_losses_to_file(baseline_losses, args.export_result)
9141056

9151057
# Analysis and reporting
9161058
perform_loss_analysis(baseline_log, test_log, stats_file)

0 commit comments

Comments
 (0)