Skip to content

Commit a57dcc7

Browse files
committed
Add import and export options to loss_compare.py
This allows us to check whether the loss is consistent across commits/PRs.

ghstack-source-id: f4541cb
Pull-Request: #2063
1 parent 13c0a27 commit a57dcc7

File tree

3 files changed

+157
-8
lines changed

3 files changed

+157
-8
lines changed

.github/workflows/integration_test_8gpu_features.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ jobs:
9393
python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
9494
9595
# Verify the accuracy.
96-
echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity"
96+
echo "Checking FSDP4 v.s. HSDP2FSDP4 accuracy parity"
9797
export baseline_options="--parallelism.data_parallel_replicate_degree=1"
98-
export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2"
99-
python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1
98+
export test_options="--parallelism.data_parallel_replicate_degree=4"
99+
python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --import-result assets/ci_llama3_losses.txt
100100
101101
# Cleanup the checkpoints so that we don't waste network bandwidth and time.
102102
rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint

assets/ci_llama3_losses.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
1 8.1376
2+
2 7.841
3+
3 7.1815
4+
4 6.3509
5+
5 5.5272
6+
6 4.9244
7+
7 4.5606
8+
8 4.3724
9+
9 4.347
10+
10 4.2001

scripts/loss_compare.py

Lines changed: 144 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ def validate_arguments(
168168
test_train_file: str,
169169
test_options: str,
170170
steps: int,
171+
assert_equal: bool,
172+
export_result: str | None,
173+
import_result: str | None,
171174
) -> None:
172175
"""Validate command line arguments."""
173176
# Validate commit arguments - if one is ".", both must be "."
@@ -201,6 +204,34 @@ def validate_arguments(
201204
log_print(f"Error: --steps must be a positive integer, got: {steps}")
202205
sys.exit(1)
203206

207+
# Validate export-result requires assert-equal
208+
if export_result and not assert_equal:
209+
log_print("Error: --export-result requires --assert-equal")
210+
log_print(" Export only happens when losses are verified to match")
211+
sys.exit(1)
212+
213+
# Validate import-result requires assert-equal
214+
if import_result and not assert_equal:
215+
log_print("Error: --import-result requires --assert-equal")
216+
log_print(" Import is used to verify all losses match")
217+
sys.exit(1)
218+
219+
# Validate export-result and import-result are mutually exclusive
220+
if export_result and import_result:
221+
log_print(
222+
"Error: --export-result and --import-result cannot be " "used together"
223+
)
224+
log_print(
225+
" Use export to save results or import to compare "
226+
"against saved results"
227+
)
228+
sys.exit(1)
229+
230+
# Validate import file exists
231+
if import_result and not os.path.exists(import_result):
232+
log_print(f"Error: Import file does not exist: {import_result}")
233+
sys.exit(1)
234+
204235

205236
# =============================================================================
206237
# SETUP FUNCTIONS
@@ -433,6 +464,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]:
433464
return losses
434465

435466

467+
def export_losses_to_file(losses: dict[int, float], export_path: str) -> None:
    """Export losses to a file and echo them to stdout.

    Each step is emitted as one ``{step} {loss}`` line, in ascending step
    order. The same lines are written to ``export_path`` and printed to stdout.

    Args:
        losses: Dictionary mapping step numbers to loss values
        export_path: Path to export file (overwritten if it already exists)
    """
    log_print(f"Exporting losses to {export_path}")

    # Format once so the file and stdout are guaranteed to carry identical text.
    lines = [f"{step} {losses[step]}" for step in sorted(losses)]

    # Explicit encoding keeps the file independent of the platform locale.
    with open(export_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)

    log_print(f"Exported {len(losses)} loss values:")
    log_print()

    # Output to stdout in same format
    for line in lines:
        print(line)

    log_print()
    log_print(f"Losses saved to: {export_path}")
494+
436495
def extract_loss_data(output_folder: str | None) -> None:
437496
"""Extract loss data from logs."""
438497
if not output_folder:
@@ -556,13 +615,18 @@ def perform_loss_analysis(
556615
generate_summary_statistics(baseline_losses, test_losses, stats_file)
557616

558617

559-
def assert_losses_equal(baseline_log: str, test_log: str) -> None:
560-
"""Assert that losses are equal between baseline and test using
561-
unittest.
618+
def assert_losses_equal(
619+
baseline_log: str, test_log: str, import_result: str | None = None
620+
) -> None:
621+
"""Assert that losses are equal between baseline and test using unittest.
622+
623+
If import_result is provided, also compares baseline with imported losses.
562624
"""
563625
log_print("Asserting losses are equal...")
564626
log_print(f"Baseline log: {baseline_log}")
565627
log_print(f"Test log: {test_log}")
628+
if import_result:
629+
log_print(f"Import file: {import_result}")
566630

567631
# Extract losses from both logs
568632
baseline_losses = extract_losses_from_log(baseline_log)
@@ -579,6 +643,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None:
579643
log_print("Error: No losses found in test log")
580644
sys.exit(1)
581645

646+
# Load imported losses if provided
647+
imported_losses = None
648+
if import_result:
649+
imported_losses = read_losses_from_file(import_result)
650+
log_print(f"Loaded {len(imported_losses)} steps from import file")
651+
if not imported_losses:
652+
log_print("Error: No losses found in import file")
653+
sys.exit(1)
654+
582655
# Create a test case
583656
class LossEqualityTest(unittest.TestCase):
584657
def test_losses_equal(self):
@@ -593,17 +666,41 @@ def test_losses_equal(self):
593666
f"test has {len(test_steps)} steps",
594667
)
595668

669+
# If imported losses exist, check steps match
670+
if imported_losses:
671+
imported_steps = set(imported_losses.keys())
672+
self.assertEqual(
673+
baseline_steps,
674+
imported_steps,
675+
f"Steps mismatch: baseline has {len(baseline_steps)} steps, "
676+
f"imported has {len(imported_steps)} steps",
677+
)
678+
596679
# Check that losses are equal for each step
597680
for step in sorted(baseline_steps):
598681
baseline_loss = baseline_losses[step]
599682
test_loss = test_losses[step]
683+
684+
# Compare baseline vs test
600685
self.assertEqual(
601686
baseline_loss,
602687
test_loss,
603688
f"Loss mismatch at step {step}: "
604689
f"baseline={baseline_loss}, test={test_loss}",
605690
)
606691

692+
# Compare baseline vs imported (if provided)
693+
# No need to compare test vs imported since:
694+
# baseline==test and baseline==imported implies test==imported
695+
if imported_losses:
696+
imported_loss = imported_losses[step]
697+
self.assertEqual(
698+
baseline_loss,
699+
imported_loss,
700+
f"Loss mismatch at step {step}: "
701+
f"baseline={baseline_loss}, imported={imported_loss}",
702+
)
703+
607704
# Run the test
608705
suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest)
609706
runner = unittest.TextTestRunner(verbosity=2)
@@ -613,7 +710,13 @@ def test_losses_equal(self):
613710
log_print("Loss assertion failed!")
614711
sys.exit(1)
615712
else:
616-
log_print("All losses are equal. Assertion passed!")
713+
if import_result:
714+
log_print(
715+
"All losses are equal (baseline, test, and imported). "
716+
"Assertion passed!"
717+
)
718+
else:
719+
log_print("All losses are equal. Assertion passed!")
617720

618721

619722
def cleanup_temp_files(output_folder: str | None) -> None:
@@ -756,6 +859,24 @@ def parse_arguments() -> argparse.Namespace:
756859
"Script exits with error if losses differ."
757860
),
758861
)
862+
parser.add_argument(
863+
"--export-result",
864+
default="",
865+
help=(
866+
"Export losses to specified file path (requires --assert-equal). "
867+
"Exports only when losses match. Format: '{step} {loss}' per line."
868+
),
869+
)
870+
parser.add_argument(
871+
"--import-result",
872+
default="",
873+
help=(
874+
"Import losses from specified file path for comparison "
875+
"(requires --assert-equal). "
876+
"Compares imported losses with both baseline and test "
877+
"(all 3 must match)."
878+
),
879+
)
759880
parser.add_argument(
760881
"--job-dump-folder",
761882
default="outputs",
@@ -787,6 +908,14 @@ def parse_arguments() -> argparse.Namespace:
787908
if not args.output_folder:
788909
args.output_folder = None
789910

911+
# Convert empty export_result to None
912+
if not args.export_result:
913+
args.export_result = None
914+
915+
# Convert empty import_result to None
916+
if not args.import_result:
917+
args.import_result = None
918+
790919
return args
791920

792921

@@ -850,6 +979,9 @@ def main() -> None:
850979
args.test_train_file,
851980
args.test_options,
852981
args.steps,
982+
args.assert_equal,
983+
args.export_result,
984+
args.import_result,
853985
)
854986

855987
# Setup environment
@@ -912,7 +1044,14 @@ def main() -> None:
9121044

9131045
# Assert losses are equal if requested
9141046
if args.assert_equal:
915-
assert_losses_equal(baseline_log, test_log)
1047+
# Pass import_result if provided for 3-way comparison
1048+
assert_losses_equal(baseline_log, test_log, args.import_result)
1049+
1050+
# Export losses if requested (only after assertion passes)
1051+
if args.export_result:
1052+
# Extract baseline losses (they equal test losses since assertion passed)
1053+
baseline_losses = extract_losses_from_log(baseline_log)
1054+
export_losses_to_file(baseline_losses, args.export_result)
9161055

9171056
# Analysis and reporting
9181057
perform_loss_analysis(baseline_log, test_log, stats_file)

0 commit comments

Comments
 (0)