diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index de0672eeef..a20cd22545 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -90,13 +90,14 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 - - # Verify the accuracy. - echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity" + # Verify the accuracy first. + echo "Checking FSDP8 v.s. HSDP (4, 2) accuracy parity" export baseline_options="--parallelism.data_parallel_replicate_degree=1" - export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1 + export test_options="--parallelism.data_parallel_replicate_degree=4" + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --import-result tests/assets/losses/llama3.txt + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/* + + python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 42ad3a81be..3479875036 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -168,6 +168,9 @@ def validate_arguments( test_train_file: str, test_options: str, steps: int, + assert_equal: bool, + export_result: str | None, + import_result: str | None, ) -> None: """Validate command line arguments.""" # Validate commit arguments - if one is ".", both must be "." @@ -201,6 +204,34 @@ def validate_arguments( log_print(f"Error: --steps must be a positive integer, got: {steps}") sys.exit(1) + # Validate export-result requires assert-equal + if export_result and not assert_equal: + log_print("Error: --export-result requires --assert-equal") + log_print(" Export only happens when losses are verified to match") + sys.exit(1) + + # Validate import-result requires assert-equal + if import_result and not assert_equal: + log_print("Error: --import-result requires --assert-equal") + log_print(" Import is used to verify all losses match") + sys.exit(1) + + # Validate export-result and import-result are mutually exclusive + if export_result and import_result: + log_print( + "Error: --export-result and --import-result cannot be " "used together" + ) + log_print( + " Use export to save results or import to compare " + "against saved results" + ) + sys.exit(1) + + # Validate import file exists + if import_result and not os.path.exists(import_result): + log_print(f"Error: Import file does not exist: {import_result}") + sys.exit(1) + # ============================================================================= # SETUP FUNCTIONS @@ -433,6 +464,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]: return losses +def export_losses_to_file(losses: dict[int, float], export_path: str) -> None: + """Export losses to file and stdout. + + Args: + losses: Dictionary mapping step numbers to loss values + export_path: Path to export file + """ + log_print(f"Exporting losses to {export_path}") + + # Write to file and collect output for stdout + with open(export_path, "w") as f: + for step in sorted(losses.keys()): + loss = losses[step] + line = f"{step} {loss}" + f.write(line + "\n") + + log_print(f"Exported {len(losses)} loss values:") + log_print() + + # Output to stdout in same format + for step in sorted(losses.keys()): + loss = losses[step] + print(f"{step} {loss}") + + log_print() + log_print(f"Losses saved to: {export_path}") + + def extract_loss_data(output_folder: str | None) -> None: """Extract loss data from logs.""" if not output_folder: @@ -556,13 +615,18 @@ def perform_loss_analysis( generate_summary_statistics(baseline_losses, test_losses, stats_file) -def assert_losses_equal(baseline_log: str, test_log: str) -> None: - """Assert that losses are equal between baseline and test using - unittest. +def assert_losses_equal( + baseline_log: str, test_log: str, import_result: str | None = None +) -> None: + """Assert that losses are equal between baseline and test using unittest. + + If import_result is provided, also compares baseline with imported losses. """ log_print("Asserting losses are equal...") log_print(f"Baseline log: {baseline_log}") log_print(f"Test log: {test_log}") + if import_result: + log_print(f"Import file: {import_result}") # Extract losses from both logs baseline_losses = extract_losses_from_log(baseline_log) @@ -579,6 +643,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None: log_print("Error: No losses found in test log") sys.exit(1) + # Load imported losses if provided + imported_losses = None + if import_result: + imported_losses = read_losses_from_file(import_result) + log_print(f"Loaded {len(imported_losses)} steps from import file") + if not imported_losses: + log_print("Error: No losses found in import file") + sys.exit(1) + # Create a test case class LossEqualityTest(unittest.TestCase): def test_losses_equal(self): @@ -593,10 +666,22 @@ def test_losses_equal(self): f"test has {len(test_steps)} steps", ) + # If imported losses exist, check steps match + if imported_losses: + imported_steps = set(imported_losses.keys()) + self.assertEqual( + baseline_steps, + imported_steps, + f"Steps mismatch: baseline has {len(baseline_steps)} steps, " + f"imported has {len(imported_steps)} steps", + ) + # Check that losses are equal for each step for step in sorted(baseline_steps): baseline_loss = baseline_losses[step] test_loss = test_losses[step] + + # Compare baseline vs test self.assertEqual( baseline_loss, test_loss, @@ -604,6 +689,18 @@ def test_losses_equal(self): f"baseline={baseline_loss}, test={test_loss}", ) + # Compare baseline vs imported (if provided) + # No need to compare test vs imported since: + # baseline==test and baseline==imported implies test==imported + if imported_losses: + imported_loss = imported_losses[step] + self.assertEqual( + baseline_loss, + imported_loss, + f"Loss mismatch at step {step}: " + f"baseline={baseline_loss}, imported={imported_loss}", + ) + # Run the test suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest) runner = unittest.TextTestRunner(verbosity=2) @@ -613,7 +710,13 @@ def test_losses_equal(self): log_print("Loss assertion failed!") sys.exit(1) else: - log_print("All losses are equal. Assertion passed!") + if import_result: + log_print( + "All losses are equal (baseline, test, and imported). " + "Assertion passed!" + ) + else: + log_print("All losses are equal. Assertion passed!") def cleanup_temp_files(output_folder: str | None) -> None: @@ -756,6 +859,24 @@ def parse_arguments() -> argparse.Namespace: "Script exits with error if losses differ." ), ) + parser.add_argument( + "--export-result", + default="", + help=( + "Export losses to specified file path (requires --assert-equal). " + "Exports only when losses match. Format: '{step} {loss}' per line." + ), + ) + parser.add_argument( + "--import-result", + default="", + help=( + "Import losses from specified file path for comparison " + "(requires --assert-equal). " + "Compares imported losses with both baseline and test " + "(all 3 must match)." + ), + ) parser.add_argument( "--job-dump-folder", default="outputs", @@ -787,6 +908,14 @@ def parse_arguments() -> argparse.Namespace: if not args.output_folder: args.output_folder = None + # Convert empty export_result to None + if not args.export_result: + args.export_result = None + + # Convert empty import_result to None + if not args.import_result: + args.import_result = None + return args @@ -850,6 +979,9 @@ def main() -> None: args.test_train_file, args.test_options, args.steps, + args.assert_equal, + args.export_result, + args.import_result, ) # Setup environment @@ -912,7 +1044,14 @@ def main() -> None: # Assert losses are equal if requested if args.assert_equal: - assert_losses_equal(baseline_log, test_log) + # Pass import_result if provided for 3-way comparison + assert_losses_equal(baseline_log, test_log, args.import_result) + + # Export losses if requested (only after assertion passes) + if args.export_result: + # Extract baseline losses (they equal test losses since assertion passed) + baseline_losses = extract_losses_from_log(baseline_log) + export_losses_to_file(baseline_losses, args.export_result) # Analysis and reporting perform_loss_analysis(baseline_log, test_log, stats_file) diff --git a/tests/assets/losses/llama3.txt b/tests/assets/losses/llama3.txt new file mode 100644 index 0000000000..5ccea64b17 --- /dev/null +++ b/tests/assets/losses/llama3.txt @@ -0,0 +1,10 @@ +1 8.1376 +2 7.841 +3 7.1815 +4 6.3509 +5 5.5272 +6 4.9244 +7 4.5606 +8 4.3724 +9 4.347 +10 4.2004