diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index 4e5fea081f..348f231575 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -75,6 +75,9 @@ from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.nan_detector import (
+    check_total_loss_nan,
+)
 from deepmd.utils.path import (
     DPH5Path,
 )
@@ -859,6 +862,9 @@ def log_loss_valid(_task_key="Default"):
 
             if not self.multi_task:
                 train_results = log_loss_train(loss, more_loss)
+                # Check for NaN in total loss using CPU values from lcurve computation
+                if self.rank == 0 and "rmse" in train_results:
+                    check_total_loss_nan(display_step_id, train_results["rmse"])
                 valid_results = log_loss_valid()
                 if self.rank == 0:
                     log.info(
@@ -900,6 +906,11 @@ def log_loss_valid(_task_key="Default"):
                         loss, more_loss, _task_key=_key
                     )
                     valid_results[_key] = log_loss_valid(_task_key=_key)
+                    # Check for NaN in total loss using CPU values from lcurve computation
+                    if self.rank == 0 and "rmse" in train_results[_key]:
+                        check_total_loss_nan(
+                            display_step_id, train_results[_key]["rmse"]
+                        )
                 if self.rank == 0:
                     log.info(
                         format_training_message_per_task(
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 52d2888081..a63be6c57d 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -75,6 +75,9 @@ from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.nan_detector import (
+    check_total_loss_nan,
+)
 
 if torch.__version__.startswith("2"):
     import torch._dynamo
@@ -949,6 +952,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
 
             if not self.multi_task:
                 train_results = log_loss_train(loss, more_loss)
+                # Check for NaN in total loss using CPU values from lcurve computation
+                if self.rank == 0 and "rmse" in train_results:
+                    check_total_loss_nan(display_step_id, train_results["rmse"])
                 valid_results = log_loss_valid()
                 if self.rank == 0:
                     log.info(
@@ -997,6 +1003,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                         loss, more_loss, _task_key=_key
                     )
                     valid_results[_key] = log_loss_valid(_task_key=_key)
+                    # Check for NaN in total loss using CPU values from lcurve computation
+                    if self.rank == 0 and "rmse" in train_results[_key]:
+                        check_total_loss_nan(
+                            display_step_id, train_results[_key]["rmse"]
+                        )
                 if self.rank == 0:
                     log.info(
                         format_training_message_per_task(
diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py
index f70c919301..0f26c00171 100644
--- a/deepmd/tf/train/trainer.py
+++ b/deepmd/tf/train/trainer.py
@@ -60,6 +60,9 @@ from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.nan_detector import (
+    check_total_loss_nan,
+)
 
 log = logging.getLogger(__name__)
 
@@ -684,6 +687,11 @@ def valid_on_the_fly(
         cur_batch = self.cur_batch
         current_lr = run_sess(self.sess, self.learning_rate)
+
+        # Check for NaN in total loss before writing to file and saving checkpoint
+        # We check the main total loss component that represents training loss
+        check_total_loss_nan(cur_batch, train_results["rmse"])
+
         if print_header:
             self.print_header(fp, train_results, valid_results)
         self.print_on_training(
diff --git a/deepmd/utils/nan_detector.py b/deepmd/utils/nan_detector.py
new file mode 100644
index 0000000000..7c5095322f
--- /dev/null
+++ b/deepmd/utils/nan_detector.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Utilities for detecting NaN values in loss during training."""
+
+import logging
+import math
+
+log = logging.getLogger(__name__)
+
+
+class LossNaNError(RuntimeError):
+    """Exception raised when NaN is detected in total loss during training."""
+
+    def __init__(self, step: int, total_loss: float) -> None:
+        """Initialize the exception.
+
+        Parameters
+        ----------
+        step : int
+            The training step where NaN was detected
+        total_loss : float
+            The total loss value that contains NaN
+        """
+        self.step = step
+        self.total_loss = total_loss
+        message = (
+            f"NaN detected in total loss at training step {step}: {total_loss}. "
+            f"Training stopped to prevent wasting time with corrupted parameters. "
+            f"This typically indicates unstable training conditions such as "
+            f"learning rate too high, poor data quality, or numerical instability."
+        )
+        super().__init__(message)
+
+
+def check_total_loss_nan(step: int, total_loss: float) -> None:
+    """Check if the total loss contains NaN and raise an exception if found.
+
+    This function is designed to be called during training after the total loss
+    is computed and converted to a CPU float value.
+
+    Parameters
+    ----------
+    step : int
+        Current training step
+    total_loss : float
+        Total loss value to check for NaN
+
+    Raises
+    ------
+    LossNaNError
+        If the total loss contains NaN
+    """
+    if math.isnan(total_loss):
+        log.error(f"NaN detected in total loss at step {step}: {total_loss}")
+        raise LossNaNError(step, total_loss)
diff --git a/source/tests/common/test_nan_detector.py b/source/tests/common/test_nan_detector.py
new file mode 100644
index 0000000000..250f9205fb
--- /dev/null
+++ b/source/tests/common/test_nan_detector.py
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Test cases for NaN detection utility."""
+
+import math
+import unittest
+
+import numpy as np
+
+from deepmd.utils.nan_detector import (
+    LossNaNError,
+    check_total_loss_nan,
+)
+
+
+class TestNaNDetector(unittest.TestCase):
+    """Test the NaN detection utility functions."""
+
+    def test_normal_values_pass(self):
+        """Test that normal loss values don't trigger NaN detection."""
+        # Test with various normal values
+        normal_losses = [0.5, 1.0, 0.001, 0.0, -0.5]
+
+        # Should not raise any exception
+        for i, loss_val in enumerate(normal_losses):
+            check_total_loss_nan(100 + i, loss_val)
+
+    def test_nan_detection_raises_exception(self):
+        """Test that NaN values trigger the proper exception."""
+        # Test with NaN value
+        with self.assertRaises(LossNaNError) as context:
+            check_total_loss_nan(200, float("nan"))
+
+        exception = context.exception
+        self.assertEqual(exception.step, 200)
+        self.assertTrue(math.isnan(exception.total_loss))
+        self.assertIn("NaN detected in total loss at training step 200", str(exception))
+
+    def test_various_nan_representations(self):
+        """Test detection of various NaN representations."""
+        nan_values = [
+            float("nan"),
+            np.nan,
+            math.nan,
+        ]
+
+        for i, nan_val in enumerate(nan_values):
+            with self.assertRaises(LossNaNError):
+                check_total_loss_nan(i, nan_val)
+
+    def test_error_message_format(self):
+        """Test that error messages contain useful information."""
+        with self.assertRaises(LossNaNError) as context:
+            check_total_loss_nan(123, float("nan"))
+
+        error_msg = str(context.exception)
+
+        # Check key information is in the message
+        self.assertIn("step 123", error_msg)
+        self.assertIn("Training stopped", error_msg)
+        self.assertIn("learning rate too high", error_msg)
+
+    def test_edge_cases(self):
+        """Test edge cases for NaN detection."""
+        # Infinity should not trigger NaN detection (separate issue)
+        try:
+            check_total_loss_nan(1, float("inf"))
+            check_total_loss_nan(2, float("-inf"))
+        except Exception as e:
+            self.fail(f"Infinity should not raise NaN exception: {e}")
+
+    def test_numeric_types(self):
+        """Test that various numeric types work correctly."""
+        # Various numeric types that should pass
+        test_values = [
+            0.5,  # float
+            1,  # int
+            np.float32(0.3),  # NumPy float32
+            np.float64(0.7),  # NumPy float64
+        ]
+
+        for i, val in enumerate(test_values):
+            try:
+                check_total_loss_nan(10 + i, float(val))
+            except Exception as e:
+                self.fail(f"Numeric type {type(val)} should not raise exception: {e}")
+
+    def test_inheritance_from_runtime_error(self):
+        """Test that LossNaNError inherits from RuntimeError."""
+        self.assertTrue(issubclass(LossNaNError, RuntimeError))
+
+        try:
+            check_total_loss_nan(999, float("nan"))
+        except LossNaNError as e:
+            self.assertIsInstance(e, RuntimeError)
+        except Exception:
+            self.fail("Should raise LossNaNError which inherits from RuntimeError")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/source/tests/common/test_nan_integration.py b/source/tests/common/test_nan_integration.py
new file mode 100644
index 0000000000..6a754d93f4
--- /dev/null
+++ b/source/tests/common/test_nan_integration.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Integration test to verify NaN detection during training.
+
+This test creates a mock training scenario where total loss becomes NaN
+and verifies that the training stops with appropriate error message.
+"""
+
+import unittest
+from unittest.mock import (
+    patch,
+)
+
+from deepmd.utils.nan_detector import (
+    LossNaNError,
+    check_total_loss_nan,
+)
+
+
+class TestNaNDetectionIntegration(unittest.TestCase):
+    """Integration tests for NaN detection during training."""
+
+    def test_training_stops_on_nan_loss(self):
+        """Test that training stops when NaN is detected in total loss."""
+        # Normal total loss should pass
+        try:
+            check_total_loss_nan(100, 0.1)
+        except Exception as e:
+            self.fail(f"Normal total loss should not raise exception: {e}")
+
+        # NaN total loss should raise
+        with self.assertRaises(LossNaNError) as context:
+            check_total_loss_nan(100, float("nan"))
+
+        exception = context.exception
+        self.assertEqual(exception.step, 100)
+        self.assertIn("NaN detected in total loss", str(exception))
+
+    @patch("deepmd.utils.nan_detector.log")
+    def test_logging_on_nan_detection(self, mock_log):
+        """Test that NaN detection logs appropriate error messages."""
+        with self.assertRaises(LossNaNError):
+            check_total_loss_nan(200, float("nan"))
+
+        # Verify that error was logged
+        mock_log.error.assert_called_once()
+        logged_message = mock_log.error.call_args[0][0]
+        self.assertIn("NaN detected in total loss at step 200", logged_message)
+
+    def test_training_simulation_with_checkpoint_prevention(self):
+        """Simulate the training checkpoint scenario to ensure NaN prevents saving."""
+        # Simulate the training flow: check total loss, then save checkpoint
+        step_id = 1000
+        total_loss = float("nan")
+
+        # This should raise LossNaNError, preventing any subsequent checkpoint saving
+        with self.assertRaises(LossNaNError) as context:
+            check_total_loss_nan(step_id, total_loss)
+
+        # Verify the error contains expected information
+        exception = context.exception
+        self.assertIn("Training stopped to prevent wasting time", str(exception))
+        self.assertIn("corrupted parameters", str(exception))
+
+    def test_realistic_training_scenario(self):
+        """Test a more realistic training scenario with decreasing then NaN loss."""
+        # Simulate normal training progression
+        normal_steps = [
+            (1, 1.0),  # Initial high loss
+            (10, 0.5),  # Loss decreasing
+            (20, 0.25),  # Loss continuing to decrease
+            (50, 0.1),  # Good progress
+        ]
+
+        # All normal steps should pass
+        for step, loss_val in normal_steps:
+            try:
+                check_total_loss_nan(step, loss_val)
+            except Exception as e:
+                self.fail(
+                    f"Normal training step {step} should not raise exception: {e}"
+                )
+
+        # But when loss becomes NaN, training should stop
+        with self.assertRaises(LossNaNError) as context:
+            check_total_loss_nan(100, float("nan"))
+
+        exception = context.exception
+        self.assertEqual(exception.step, 100)
+        self.assertIn("Training stopped", str(exception))
+
+
+if __name__ == "__main__":
+    unittest.main()
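
For reference, a minimal sketch of the failure mode this patch introduces: the guard raises before the checkpoint for the offending step is written, so a NaN never reaches disk. `save_checkpoint` below is a hypothetical stand-in for the trainers' real checkpoint call, and the loss history is fabricated for illustration; only `check_total_loss_nan` and `LossNaNError` come from the patch itself.

```python
from deepmd.utils.nan_detector import LossNaNError, check_total_loss_nan


def save_checkpoint(step: int) -> None:
    """Hypothetical placeholder for the trainers' checkpoint-saving call."""
    print(f"checkpoint saved at step {step}")


# Simulated lcurve values: the loss decays normally, then diverges to NaN.
history = [(1, 1.0), (2, 0.5), (3, float("nan"))]

try:
    for step, total_loss in history:
        # Raises LossNaNError at step 3, before save_checkpoint(3) runs.
        check_total_loss_nan(step, total_loss)
        save_checkpoint(step)
except LossNaNError as err:
    print(f"training aborted: {err}")
```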