@@ -168,6 +168,9 @@ def validate_arguments(
168168 test_train_file : str ,
169169 test_options : str ,
170170 steps : int ,
171+ assert_equal : bool ,
172+ export_result : str | None ,
173+ import_result : str | None ,
171174) -> None :
172175 """Validate command line arguments."""
173176 # Validate commit arguments - if one is ".", both must be "."
@@ -201,6 +204,34 @@ def validate_arguments(
201204 log_print (f"Error: --steps must be a positive integer, got: { steps } " )
202205 sys .exit (1 )
203206
207+ # Validate export-result requires assert-equal
208+ if export_result and not assert_equal :
209+ log_print ("Error: --export-result requires --assert-equal" )
210+ log_print (" Export only happens when losses are verified to match" )
211+ sys .exit (1 )
212+
213+ # Validate import-result requires assert-equal
214+ if import_result and not assert_equal :
215+ log_print ("Error: --import-result requires --assert-equal" )
216+ log_print (" Import is used to verify all losses match" )
217+ sys .exit (1 )
218+
219+ # Validate export-result and import-result are mutually exclusive
220+ if export_result and import_result :
221+ log_print (
222+ "Error: --export-result and --import-result cannot be " "used together"
223+ )
224+ log_print (
225+ " Use export to save results or import to compare "
226+ "against saved results"
227+ )
228+ sys .exit (1 )
229+
230+ # Validate import file exists
231+ if import_result and not os .path .exists (import_result ):
232+ log_print (f"Error: Import file does not exist: { import_result } " )
233+ sys .exit (1 )
234+
204235
205236# =============================================================================
206237# SETUP FUNCTIONS
@@ -433,6 +464,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]:
433464 return losses
434465
435466
def export_losses_to_file(losses: dict[int, float], export_path: str) -> None:
    """Export losses to a file and echo them to stdout.

    Each loss is written as one ``{step} {loss}`` line, ordered by ascending
    step number. The same lines are then printed to stdout, surrounded by
    ``log_print`` status messages.

    Args:
        losses: Dictionary mapping step numbers to loss values.
        export_path: Path of the file to write; an existing file is
            overwritten.
    """
    log_print(f"Exporting losses to {export_path}")

    # Format every line exactly once so the file and stdout cannot drift apart.
    lines = [f"{step} {loss}" for step, loss in sorted(losses.items())]

    # Explicit encoding keeps the exported file identical across platforms.
    with open(export_path, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in lines)

    log_print(f"Exported {len(losses)} loss values:")
    log_print()

    # Echo to stdout in the same format as the file.
    for line in lines:
        print(line)

    log_print()
    log_print(f"Losses saved to: {export_path}")
494+
436495def extract_loss_data (output_folder : str | None ) -> None :
437496 """Extract loss data from logs."""
438497 if not output_folder :
@@ -556,13 +615,18 @@ def perform_loss_analysis(
556615 generate_summary_statistics (baseline_losses , test_losses , stats_file )
557616
558617
559- def assert_losses_equal (baseline_log : str , test_log : str ) -> None :
560- """Assert that losses are equal between baseline and test using
561- unittest.
618+ def assert_losses_equal (
619+ baseline_log : str , test_log : str , import_result : str | None = None
620+ ) -> None :
621+ """Assert that losses are equal between baseline and test using unittest.
622+
623+ If import_result is provided, also compares baseline with imported losses.
562624 """
563625 log_print ("Asserting losses are equal..." )
564626 log_print (f"Baseline log: { baseline_log } " )
565627 log_print (f"Test log: { test_log } " )
628+ if import_result :
629+ log_print (f"Import file: { import_result } " )
566630
567631 # Extract losses from both logs
568632 baseline_losses = extract_losses_from_log (baseline_log )
@@ -579,6 +643,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None:
579643 log_print ("Error: No losses found in test log" )
580644 sys .exit (1 )
581645
646+ # Load imported losses if provided
647+ imported_losses = None
648+ if import_result :
649+ imported_losses = read_losses_from_file (import_result )
650+ log_print (f"Loaded { len (imported_losses )} steps from import file" )
651+ if not imported_losses :
652+ log_print ("Error: No losses found in import file" )
653+ sys .exit (1 )
654+
582655 # Create a test case
583656 class LossEqualityTest (unittest .TestCase ):
584657 def test_losses_equal (self ):
@@ -593,17 +666,41 @@ def test_losses_equal(self):
593666 f"test has { len (test_steps )} steps" ,
594667 )
595668
669+ # If imported losses exist, check steps match
670+ if imported_losses :
671+ imported_steps = set (imported_losses .keys ())
672+ self .assertEqual (
673+ baseline_steps ,
674+ imported_steps ,
675+ f"Steps mismatch: baseline has { len (baseline_steps )} steps, "
676+ f"imported has { len (imported_steps )} steps" ,
677+ )
678+
596679 # Check that losses are equal for each step
597680 for step in sorted (baseline_steps ):
598681 baseline_loss = baseline_losses [step ]
599682 test_loss = test_losses [step ]
683+
684+ # Compare baseline vs test
600685 self .assertEqual (
601686 baseline_loss ,
602687 test_loss ,
603688 f"Loss mismatch at step { step } : "
604689 f"baseline={ baseline_loss } , test={ test_loss } " ,
605690 )
606691
692+ # Compare baseline vs imported (if provided)
693+ # No need to compare test vs imported since:
694+ # baseline==test and baseline==imported implies test==imported
695+ if imported_losses :
696+ imported_loss = imported_losses [step ]
697+ self .assertEqual (
698+ baseline_loss ,
699+ imported_loss ,
700+ f"Loss mismatch at step { step } : "
701+ f"baseline={ baseline_loss } , imported={ imported_loss } " ,
702+ )
703+
607704 # Run the test
608705 suite = unittest .TestLoader ().loadTestsFromTestCase (LossEqualityTest )
609706 runner = unittest .TextTestRunner (verbosity = 2 )
@@ -613,7 +710,13 @@ def test_losses_equal(self):
613710 log_print ("Loss assertion failed!" )
614711 sys .exit (1 )
615712 else :
616- log_print ("All losses are equal. Assertion passed!" )
713+ if import_result :
714+ log_print (
715+ "All losses are equal (baseline, test, and imported). "
716+ "Assertion passed!"
717+ )
718+ else :
719+ log_print ("All losses are equal. Assertion passed!" )
617720
618721
619722def cleanup_temp_files (output_folder : str | None ) -> None :
@@ -756,6 +859,24 @@ def parse_arguments() -> argparse.Namespace:
756859 "Script exits with error if losses differ."
757860 ),
758861 )
862+ parser .add_argument (
863+ "--export-result" ,
864+ default = "" ,
865+ help = (
866+ "Export losses to specified file path (requires --assert-equal). "
867+ "Exports only when losses match. Format: '{step} {loss}' per line."
868+ ),
869+ )
870+ parser .add_argument (
871+ "--import-result" ,
872+ default = "" ,
873+ help = (
874+ "Import losses from specified file path for comparison "
875+ "(requires --assert-equal). "
876+ "Compares imported losses with both baseline and test "
877+ "(all 3 must match)."
878+ ),
879+ )
759880 parser .add_argument (
760881 "--job-dump-folder" ,
761882 default = "outputs" ,
@@ -787,6 +908,14 @@ def parse_arguments() -> argparse.Namespace:
787908 if not args .output_folder :
788909 args .output_folder = None
789910
911+ # Convert empty export_result to None
912+ if not args .export_result :
913+ args .export_result = None
914+
915+ # Convert empty import_result to None
916+ if not args .import_result :
917+ args .import_result = None
918+
790919 return args
791920
792921
@@ -850,6 +979,9 @@ def main() -> None:
850979 args .test_train_file ,
851980 args .test_options ,
852981 args .steps ,
982+ args .assert_equal ,
983+ args .export_result ,
984+ args .import_result ,
853985 )
854986
855987 # Setup environment
@@ -912,7 +1044,14 @@ def main() -> None:
9121044
9131045 # Assert losses are equal if requested
9141046 if args .assert_equal :
915- assert_losses_equal (baseline_log , test_log )
1047+ # Pass import_result if provided for 3-way comparison
1048+ assert_losses_equal (baseline_log , test_log , args .import_result )
1049+
1050+ # Export losses if requested (only after assertion passes)
1051+ if args .export_result :
1052+ # Extract baseline losses (they equal test losses since assertion passed)
1053+ baseline_losses = extract_losses_from_log (baseline_log )
1054+ export_losses_to_file (baseline_losses , args .export_result )
9161055
9171056 # Analysis and reporting
9181057 perform_loss_analysis (baseline_log , test_log , stats_file )
0 commit comments