@@ -168,6 +168,9 @@ def validate_arguments(
168168 test_train_file : str ,
169169 test_options : str ,
170170 steps : int ,
171+ assert_equal : bool ,
172+ export_result : str | None ,
173+ import_result : str | None ,
171174) -> None :
172175 """Validate command line arguments."""
173176 # Validate commit arguments - if one is ".", both must be "."
@@ -201,6 +204,34 @@ def validate_arguments(
201204 log_print (f"Error: --steps must be a positive integer, got: { steps } " )
202205 sys .exit (1 )
203206
207+ # Validate export-result requires assert-equal
208+ if export_result and not assert_equal :
209+ log_print ("Error: --export-result requires --assert-equal" )
210+ log_print (" Export only happens when losses are verified to match" )
211+ sys .exit (1 )
212+
213+ # Validate import-result requires assert-equal
214+ if import_result and not assert_equal :
215+ log_print ("Error: --import-result requires --assert-equal" )
216+ log_print (" Import is used to verify all losses match" )
217+ sys .exit (1 )
218+
219+ # Validate export-result and import-result are mutually exclusive
220+ if export_result and import_result :
221+ log_print (
222+ "Error: --export-result and --import-result cannot be " "used together"
223+ )
224+ log_print (
225+ " Use export to save results or import to compare "
226+ "against saved results"
227+ )
228+ sys .exit (1 )
229+
230+ # Validate import file exists
231+ if import_result and not os .path .exists (import_result ):
232+ log_print (f"Error: Import file does not exist: { import_result } " )
233+ sys .exit (1 )
234+
204235
205236# =============================================================================
206237# SETUP FUNCTIONS
@@ -321,7 +352,10 @@ def check_git_clean_state() -> None:
321352 for line in result .stdout .strip ().split ("\n " ):
322353 log_print (f" { line } " )
323354 log_print ("" )
324- log_print ("Please commit, stash, or discard your changes before running this script" )
355+ log_print (
356+ "Please commit, stash, or discard your changes before "
357+ "running this script"
358+ )
325359 log_print (" - To commit: git add -A && git commit -m 'message'" )
326360 log_print (" - To stash: git stash" )
327361 log_print (" - To discard: git checkout -- . && git clean -fd" )
@@ -431,6 +465,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]:
431465 return losses
432466
433467
def export_losses_to_file(losses: dict[int, float], export_path: str) -> None:
    """Export losses to a file and echo them to stdout.

    Each loss is written as one ``{step} {loss}`` line, ordered by
    ascending step number. The same lines are then printed to stdout,
    while status messages go through ``log_print``.

    Args:
        losses: Dictionary mapping step numbers to loss values.
        export_path: Path to the export file. It is opened in ``"w"``
            mode, so an existing file at that path is overwritten.
    """
    log_print(f"Exporting losses to {export_path}")

    # Format every line exactly once so that the file contents and the
    # stdout echo cannot drift apart. Dict keys are unique, so sorting
    # the (step, loss) pairs orders purely by step.
    lines = [f"{step} {loss}" for step, loss in sorted(losses.items())]

    # Explicit encoding keeps the output independent of the locale.
    with open(export_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)

    log_print(f"Exported {len(losses)} loss values:")
    log_print()

    # Output to stdout in the same format as the file.
    for line in lines:
        print(line)

    log_print()
    log_print(f"Losses saved to: {export_path}")
495+
434496def extract_loss_data (output_folder : str | None ) -> None :
435497 """Extract loss data from logs."""
436498 if not output_folder :
@@ -554,13 +616,18 @@ def perform_loss_analysis(
554616 generate_summary_statistics (baseline_losses , test_losses , stats_file )
555617
556618
557- def assert_losses_equal (baseline_log : str , test_log : str ) -> None :
558- """Assert that losses are equal between baseline and test using
559- unittest.
619+ def assert_losses_equal (
620+ baseline_log : str , test_log : str , import_result : str | None = None
621+ ) -> None :
622+ """Assert that losses are equal between baseline and test using unittest.
623+
624+ If import_result is provided, also compares baseline with imported losses.
560625 """
561626 log_print ("Asserting losses are equal..." )
562627 log_print (f"Baseline log: { baseline_log } " )
563628 log_print (f"Test log: { test_log } " )
629+ if import_result :
630+ log_print (f"Import file: { import_result } " )
564631
565632 # Extract losses from both logs
566633 baseline_losses = extract_losses_from_log (baseline_log )
@@ -577,6 +644,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None:
577644 log_print ("Error: No losses found in test log" )
578645 sys .exit (1 )
579646
647+ # Load imported losses if provided
648+ imported_losses = None
649+ if import_result :
650+ imported_losses = read_losses_from_file (import_result )
651+ log_print (f"Loaded { len (imported_losses )} steps from import file" )
652+ if not imported_losses :
653+ log_print ("Error: No losses found in import file" )
654+ sys .exit (1 )
655+
580656 # Create a test case
581657 class LossEqualityTest (unittest .TestCase ):
582658 def test_losses_equal (self ):
@@ -591,17 +667,41 @@ def test_losses_equal(self):
591667 f"test has { len (test_steps )} steps" ,
592668 )
593669
670+ # If imported losses exist, check steps match
671+ if imported_losses :
672+ imported_steps = set (imported_losses .keys ())
673+ self .assertEqual (
674+ baseline_steps ,
675+ imported_steps ,
676+ f"Steps mismatch: baseline has { len (baseline_steps )} steps, "
677+ f"imported has { len (imported_steps )} steps" ,
678+ )
679+
594680 # Check that losses are equal for each step
595681 for step in sorted (baseline_steps ):
596682 baseline_loss = baseline_losses [step ]
597683 test_loss = test_losses [step ]
684+
685+ # Compare baseline vs test
598686 self .assertEqual (
599687 baseline_loss ,
600688 test_loss ,
601689 f"Loss mismatch at step { step } : "
602690 f"baseline={ baseline_loss } , test={ test_loss } " ,
603691 )
604692
693+ # Compare baseline vs imported (if provided)
694+ # No need to compare test vs imported since:
695+ # baseline==test and baseline==imported implies test==imported
696+ if imported_losses :
697+ imported_loss = imported_losses [step ]
698+ self .assertEqual (
699+ baseline_loss ,
700+ imported_loss ,
701+ f"Loss mismatch at step { step } : "
702+ f"baseline={ baseline_loss } , imported={ imported_loss } " ,
703+ )
704+
605705 # Run the test
606706 suite = unittest .TestLoader ().loadTestsFromTestCase (LossEqualityTest )
607707 runner = unittest .TextTestRunner (verbosity = 2 )
@@ -611,7 +711,13 @@ def test_losses_equal(self):
611711 log_print ("Loss assertion failed!" )
612712 sys .exit (1 )
613713 else :
614- log_print ("All losses are equal. Assertion passed!" )
714+ if import_result :
715+ log_print (
716+ "All losses are equal (baseline, test, and imported). "
717+ "Assertion passed!"
718+ )
719+ else :
720+ log_print ("All losses are equal. Assertion passed!" )
615721
616722
617723def cleanup_temp_files (output_folder : str | None ) -> None :
@@ -754,6 +860,24 @@ def parse_arguments() -> argparse.Namespace:
754860 "Script exits with error if losses differ."
755861 ),
756862 )
863+ parser .add_argument (
864+ "--export-result" ,
865+ default = "" ,
866+ help = (
867+ "Export losses to specified file path (requires --assert-equal). "
868+ "Exports only when losses match. Format: '{step} {loss}' per line."
869+ ),
870+ )
871+ parser .add_argument (
872+ "--import-result" ,
873+ default = "" ,
874+ help = (
875+ "Import losses from specified file path for comparison "
876+ "(requires --assert-equal). "
877+ "Compares imported losses with both baseline and test "
878+ "(all 3 must match)."
879+ ),
880+ )
757881 parser .add_argument (
758882 "--job-dump-folder" ,
759883 default = "outputs" ,
@@ -785,6 +909,14 @@ def parse_arguments() -> argparse.Namespace:
785909 if not args .output_folder :
786910 args .output_folder = None
787911
912+ # Convert empty export_result to None
913+ if not args .export_result :
914+ args .export_result = None
915+
916+ # Convert empty import_result to None
917+ if not args .import_result :
918+ args .import_result = None
919+
788920 return args
789921
790922
@@ -848,6 +980,9 @@ def main() -> None:
848980 args .test_train_file ,
849981 args .test_options ,
850982 args .steps ,
983+ args .assert_equal ,
984+ args .export_result ,
985+ args .import_result ,
851986 )
852987
853988 # Setup environment
@@ -910,7 +1045,14 @@ def main() -> None:
9101045
9111046 # Assert losses are equal if requested
9121047 if args .assert_equal :
913- assert_losses_equal (baseline_log , test_log )
1048+ # Pass import_result if provided for 3-way comparison
1049+ assert_losses_equal (baseline_log , test_log , args .import_result )
1050+
1051+ # Export losses if requested (only after assertion passes)
1052+ if args .export_result :
1053+ # Extract baseline losses (they equal test losses since assertion passed)
1054+ baseline_losses = extract_losses_from_log (baseline_log )
1055+ export_losses_to_file (baseline_losses , args .export_result )
9141056
9151057 # Analysis and reporting
9161058 perform_loss_analysis (baseline_log , test_log , stats_file )
0 commit comments