5 changes: 5 additions & 0 deletions deepmd/pd/train/training.py
@@ -75,6 +75,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.nan_detector import (
check_total_loss_nan,
)
from deepmd.utils.path import (
DPH5Path,
)
@@ -778,6 +781,8 @@ def step(_step_id, task_key="Default") -> None:
label=label_dict,
task_key=task_key,
)
# Check for NaN in total loss before backward pass to prevent corrupted training
check_total_loss_nan(_step_id + 1, loss.item())

with nvprof_context(enable_profiling, "Backward pass"):
loss.backward()
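For orientation, here is a minimal sketch of where the check sits relative to the backward pass; `forward_fn`, `optimizer`, and `batch` are placeholder names, not the actual deepmd wrapper API:

from deepmd.utils.nan_detector import check_total_loss_nan

def training_step(forward_fn, optimizer, batch, step_id):
    # Illustrative only: placeholder callables, not the deepmd Paddle training loop.
    loss = forward_fn(**batch)  # forward pass yields the total loss tensor
    check_total_loss_nan(step_id + 1, loss.item())  # abort before any gradients are computed
    loss.backward()  # only reached when the loss is finite
    optimizer.step()
    optimizer.clear_grad()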
9 changes: 9 additions & 0 deletions deepmd/pt/train/training.py
@@ -75,6 +75,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.nan_detector import (
check_total_loss_nan,
)

if torch.__version__.startswith("2"):
import torch._dynamo
@@ -761,6 +764,8 @@ def step(_step_id: int, task_key: str = "Default") -> None:
model_pred, loss, more_loss = self.wrapper(
**input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
)
# Check for NaN in total loss before backward pass to prevent corrupted training
check_total_loss_nan(_step_id + 1, loss.item())
loss.backward()
if self.gradient_max_norm > 0.0:
torch.nn.utils.clip_grad_norm_(
@@ -812,6 +817,8 @@ def fake_model() -> dict:
int(input_dict["atype"].shape[-1]),
learning_rate=pref_lr,
)
# Check for NaN in total loss before continuing training
check_total_loss_nan(_step_id + 1, loss.item())
elif isinstance(self.loss, DenoiseLoss):
KFOptWrapper = KFOptimizerWrapper(
self.wrapper,
@@ -838,6 +845,8 @@ def fake_model() -> dict:
input_dict["natoms"],
learning_rate=pref_lr,
)
# Check for NaN in total loss before continuing training
check_total_loss_nan(_step_id + 1, loss.item())
else:
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

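The change only raises the exception; handling it is left to the caller. A hedged sketch of how a driver script could react so that no checkpoint is written from corrupted parameters (`trainer` and `save_checkpoint` are hypothetical names):

from deepmd.utils.nan_detector import LossNaNError

def run_training(trainer, save_checkpoint):
    # Hypothetical driver-side handling; the real deepmd entry points may differ.
    try:
        trainer.run()  # raises LossNaNError if the total loss goes NaN mid-training
    except LossNaNError as err:
        raise SystemExit(f"training aborted: {err}")  # exit without saving a checkpoint
    save_checkpoint()  # only reached when training finished with finite losses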
18 changes: 18 additions & 0 deletions deepmd/tf/train/trainer.py
@@ -60,6 +60,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.nan_detector import (
check_total_loss_nan,
)

log = logging.getLogger(__name__)

@@ -684,6 +687,21 @@ def valid_on_the_fly(

cur_batch = self.cur_batch
current_lr = run_sess(self.sess, self.learning_rate)

# Check for NaN in the total loss before writing to file and saving a checkpoint.
# The first entry of each results dict is taken as the main (total) loss component.
if train_results:
    main_loss_key = next(iter(train_results))
    check_total_loss_nan(cur_batch, train_results[main_loss_key])

if valid_results:
    # Check the validation loss as well for consistency
    main_loss_key = next(iter(valid_results))
    check_total_loss_nan(cur_batch, valid_results[main_loss_key])

if print_header:
self.print_header(fp, train_results, valid_results)
self.print_on_training(
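As an illustration of the first-key heuristic above, a small standalone sketch; the dict contents and key names are made up, since the real keys depend on the configured loss:

from deepmd.utils.nan_detector import check_total_loss_nan

# Assumed layout of a per-step results dict; actual keys depend on the loss settings.
train_results = {"rmse": 0.12, "rmse_e": 0.03, "rmse_f": 0.09}
cur_batch = 500

main_loss_key = next(iter(train_results))  # the first entry is treated as the total loss
check_total_loss_nan(cur_batch, train_results[main_loss_key])  # raises LossNaNError only on NaN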
54 changes: 54 additions & 0 deletions deepmd/utils/nan_detector.py
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for detecting NaN values in loss during training."""

import logging
import math

log = logging.getLogger(__name__)


class LossNaNError(RuntimeError):
"""Exception raised when NaN is detected in total loss during training."""

def __init__(self, step: int, total_loss: float) -> None:
"""Initialize the exception.

Parameters
----------
step : int
The training step where NaN was detected
total_loss : float
The total loss value that contains NaN
"""
self.step = step
self.total_loss = total_loss
message = (
    f"NaN detected in total loss at training step {step}: {total_loss}. "
    "Training stopped to prevent wasting time with corrupted parameters. "
    "This typically indicates unstable training conditions such as "
    "learning rate too high, poor data quality, or numerical instability."
)
super().__init__(message)


def check_total_loss_nan(step: int, total_loss: float) -> None:
"""Check if the total loss contains NaN and raise an exception if found.

This function is designed to be called during training after the total loss
is computed and converted to a CPU float value.

Parameters
----------
step : int
Current training step
total_loss : float
Total loss value to check for NaN

Raises
------
LossNaNError
If the total loss contains NaN
"""
if math.isnan(total_loss):
log.error(f"NaN detected in total loss at step {step}: {total_loss}")
raise LossNaNError(step, total_loss)
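For readers of the diff, a minimal usage sketch of the new utility outside of any trainer (the step numbers and loss values are made up):

import math

from deepmd.utils.nan_detector import (
    LossNaNError,
    check_total_loss_nan,
)

check_total_loss_nan(10, 0.42)  # finite loss: returns None, training would continue
try:
    check_total_loss_nan(11, math.nan)  # NaN loss: logs an error and raises
except LossNaNError as err:
    assert err.step == 11
    assert math.isnan(err.total_loss)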
103 changes: 103 additions & 0 deletions source/tests/common/test_nan_detector.py
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Test cases for NaN detection utility."""

import math
import unittest

import numpy as np

from deepmd.utils.nan_detector import (
LossNaNError,
check_total_loss_nan,
)


class TestNaNDetector(unittest.TestCase):
"""Test the NaN detection utility functions."""

def test_normal_values_pass(self):
"""Test that normal loss values don't trigger NaN detection."""
# Test with various normal values
normal_losses = [0.5, 1.0, 0.001, 0.0, -0.5]

# Should not raise any exception
for i, loss_val in enumerate(normal_losses):
try:
check_total_loss_nan(100 + i, loss_val)
except Exception as e:
self.fail(f"Normal values should not raise exception: {e}")

def test_nan_detection_raises_exception(self):
"""Test that NaN values trigger the proper exception."""
# Test with NaN value
with self.assertRaises(LossNaNError) as context:
check_total_loss_nan(200, float("nan"))

exception = context.exception
self.assertEqual(exception.step, 200)
self.assertTrue(math.isnan(exception.total_loss))
self.assertIn("NaN detected in total loss at training step 200", str(exception))

def test_various_nan_representations(self):
"""Test detection of various NaN representations."""
nan_values = [
float("nan"),
np.nan,
math.nan,
]

for i, nan_val in enumerate(nan_values):
with self.assertRaises(LossNaNError):
check_total_loss_nan(i, nan_val)

def test_error_message_format(self):
"""Test that error messages contain useful information."""
with self.assertRaises(LossNaNError) as context:
check_total_loss_nan(123, float("nan"))

error_msg = str(context.exception)

# Check key information is in the message
self.assertIn("step 123", error_msg)
self.assertIn("Training stopped", error_msg)
self.assertIn("learning rate too high", error_msg)

def test_edge_cases(self):
"""Test edge cases for NaN detection."""
# Infinity should not trigger NaN detection (separate issue)
try:
check_total_loss_nan(1, float("inf"))
check_total_loss_nan(2, float("-inf"))
except Exception as e:
self.fail(f"Infinity should not raise NaN exception: {e}")

def test_numeric_types(self):
"""Test that various numeric types work correctly."""
# Various numeric types that should pass
test_values = [
0.5, # float
1, # int
np.float32(0.3), # NumPy float32
np.float64(0.7), # NumPy float64
]

for i, val in enumerate(test_values):
try:
check_total_loss_nan(10 + i, float(val))
except Exception as e:
self.fail(f"Numeric type {type(val)} should not raise exception: {e}")

def test_inheritance_from_runtime_error(self):
"""Test that LossNaNError inherits from RuntimeError."""
self.assertTrue(issubclass(LossNaNError, RuntimeError))

try:
check_total_loss_nan(999, float("nan"))
except LossNaNError as e:
self.assertIsInstance(e, RuntimeError)
except Exception:
self.fail("Should raise LossNaNError which inherits from RuntimeError")


if __name__ == "__main__":
unittest.main()
102 changes: 102 additions & 0 deletions source/tests/common/test_nan_integration.py
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Integration test to verify NaN detection during training.
This test creates a mock training scenario where total loss becomes NaN
and verifies that the training stops with appropriate error message.
"""

import unittest
from unittest.mock import (
patch,
)

from deepmd.utils.nan_detector import (
LossNaNError,
check_total_loss_nan,
)


class TestNaNDetectionIntegration(unittest.TestCase):
"""Integration tests for NaN detection during training."""

def test_training_stops_on_nan_loss(self):
"""Test that training stops when NaN is detected in total loss."""
# Normal total loss should pass
try:
check_total_loss_nan(100, 0.1)
except Exception as e:
self.fail(f"Normal total loss should not raise exception: {e}")

# NaN total loss should raise
with self.assertRaises(LossNaNError) as context:
check_total_loss_nan(100, float("nan"))

exception = context.exception
self.assertEqual(exception.step, 100)
self.assertIn("NaN detected in total loss", str(exception))

@patch("deepmd.utils.nan_detector.log")
def test_logging_on_nan_detection(self, mock_log):
"""Test that NaN detection logs appropriate error messages."""
with self.assertRaises(LossNaNError):
check_total_loss_nan(200, float("nan"))

# Verify that error was logged
mock_log.error.assert_called_once()
logged_message = mock_log.error.call_args[0][0]
self.assertIn("NaN detected in total loss at step 200", logged_message)

def test_training_simulation_with_checkpoint_prevention(self):
"""Simulate the training checkpoint scenario to ensure NaN prevents saving."""

def mock_save_checkpoint():
"""Mock function that should not be called when NaN is detected."""
raise AssertionError("Checkpoint should not be saved when NaN is detected!")

# Simulate the training flow: check total loss, then save checkpoint
step_id = 1000
total_loss = float("nan")

# This should raise LossNaNError before checkpoint saving
with self.assertRaises(LossNaNError):
check_total_loss_nan(step_id, total_loss)
# This line should never be reached
mock_save_checkpoint()

# Verify the error contains expected information
try:
check_total_loss_nan(step_id, total_loss)
except LossNaNError as e:
self.assertIn("Training stopped to prevent wasting time", str(e))
self.assertIn("corrupted parameters", str(e))

def test_realistic_training_scenario(self):
"""Test a more realistic training scenario with decreasing then NaN loss."""
# Simulate normal training progression
normal_steps = [
(1, 1.0), # Initial high loss
(10, 0.5), # Loss decreasing
(20, 0.25), # Loss continuing to decrease
(50, 0.1), # Good progress
]

# All normal steps should pass
for step, loss_val in normal_steps:
try:
check_total_loss_nan(step, loss_val)
except Exception as e:
self.fail(
f"Normal training step {step} should not raise exception: {e}"
)

# But when loss becomes NaN, training should stop
with self.assertRaises(LossNaNError) as context:
check_total_loss_nan(100, float("nan"))

exception = context.exception
self.assertEqual(exception.step, 100)
self.assertIn("Training stopped", str(exception))


if __name__ == "__main__":
unittest.main()