12 changes: 7 additions & 5 deletions README.md
@@ -39,13 +39,15 @@ ______________________________________________________________________
</div>

# Looking for GPUs?

Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.

- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.
- [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train.
- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize and deploy models.
- [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Persistent GPU workspaces where AI helps you code and analyze.
- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.

# Installation

72 changes: 63 additions & 9 deletions src/torchmetrics/functional/retrieval/ndcg.py
@@ -68,7 +68,39 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
return cumulative_gain


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
def _handle_empty_target(action: str, device: torch.device) -> Optional[Tensor]:
"""Return a default nDCG score when the target contains no positive labels.

Args:
action: policy for handling empty targets:
- "skip": return None (exclude from batch average)
- "pos": return a score of 1.0
- "neg": return a score of 0.0
device: the torch device on which to create the output tensor.

Returns:
A scalar tensor with the default score if action is "pos" or "neg".
None if action is "skip".

Raises:
ValueError: if ``action`` is not one of {"skip", "pos", "neg"}.

"""
if action == "skip":
return None
if action == "pos":
return torch.tensor(1.0, device=device)
if action == "neg":
return torch.tensor(0.0, device=device)
raise ValueError(f"Invalid empty_target_action: {action}")


def retrieval_normalized_dcg(
preds: Tensor,
target: Tensor,
top_k: Optional[int] = None,
empty_target_action: str = "skip",
) -> Tensor:
"""Compute `Normalized Discounted Cumulative Gain`_ (for information retrieval).

``preds`` and ``target`` should be of the same shape and live on the same device.
@@ -79,6 +111,10 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document relevance.
top_k: consider only the top k elements (default: ``None``, which considers them all)
empty_target_action: what to do when the target has no positives:
- "skip": exclude from average
- "pos": assign score 1.0
- "neg": assign score 0.0

Return:
A single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.
@@ -95,19 +131,37 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
tensor(0.6957)

"""
original_shape = preds.shape
preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

top_k = preds.shape[-1] if top_k is None else top_k
# reshape back if input was 2D
if len(original_shape) == 2:
preds = preds.view(original_shape)
target = target.view(original_shape)
else:
preds = preds.unsqueeze(0)
Comment on lines +134 to +142
Copilot AI (Sep 19, 2025): Line 143 should use `target.unsqueeze(0)` instead of `target.view(original_shape)`. The else branch handles 1D inputs, which need to be unsqueezed to match the `preds` tensor on line 142.

target = target.unsqueeze(0)

n_samples, n_labels = preds.shape
top_k = n_labels if top_k is None else top_k
top_k = min(top_k, n_labels)

if not (isinstance(top_k, int) and top_k > 0):
raise ValueError("`top_k` has to be a positive integer or None")

gain = _dcg_sample_scores(target, preds, top_k, ignore_ties=False)
normalized_gain = _dcg_sample_scores(target, target, top_k, ignore_ties=True)
scores = []
for p, t in zip(preds, target):
gain = _dcg_sample_scores(t, p, top_k, ignore_ties=False)
ideal_gain = _dcg_sample_scores(t, t, top_k, ignore_ties=True)

if ideal_gain == 0:
score = _handle_empty_target(empty_target_action, preds.device)
if score is not None:
scores.append(score)
else:
scores.append(gain / ideal_gain)

# filter undefined scores
all_irrelevant = normalized_gain == 0
gain[all_irrelevant] = 0
gain[~all_irrelevant] /= normalized_gain[~all_irrelevant]
if not scores:
return torch.tensor(0.0, device=preds.device)

return gain.mean()
return torch.stack(scores).mean()
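A quick usage sketch of the updated functional API, assuming the changes in this diff are applied: the `empty_target_action` argument and the 2D input handling are the additions from this PR, while the concrete tensors and `top_k` value below are made up for illustration.

```python
import torch
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

# one ordinary query plus one query with no relevant documents
preds = torch.tensor([[0.9, 0.3, 0.5], [0.2, 0.1, 0.4]])
target = torch.tensor([[1, 0, 1], [0, 0, 0]])

# "skip" (the default) averages only over queries with at least one positive label
retrieval_normalized_dcg(preds, target, top_k=2, empty_target_action="skip")

# "neg" scores the empty-target query as 0.0, "pos" scores it as 1.0
retrieval_normalized_dcg(preds, target, top_k=2, empty_target_action="neg")
```

Note that if every query has an empty target and `"skip"` is used, the new code returns a 0.0 tensor rather than raising.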
6 changes: 6 additions & 0 deletions tests/unittests/retrieval/_inputs.py
@@ -64,6 +64,12 @@ class _Input(NamedTuple):
),
)

_input_retrieval_scores_2d = _Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE, 2)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, 2)),
)

# with errors
_input_retrieval_scores_no_target = _Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
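For orientation, a standalone sketch of the shapes the new 2D fixture produces; the `NUM_BATCHES` and `BATCH_SIZE` values below are placeholders, not the constants defined in the test suite.

```python
import torch

NUM_BATCHES, BATCH_SIZE = 2, 4  # placeholder values for illustration only

preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 2)
target = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, 2))

# each update step consumes one (BATCH_SIZE, 2) slice,
# i.e. BATCH_SIZE queries with 2 documents each
print(preds[0].shape)   # torch.Size([4, 2])
print(target[0].dtype)  # torch.int64
```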
16 changes: 13 additions & 3 deletions tests/unittests/retrieval/helpers.py
@@ -25,6 +25,7 @@
from unittests._helpers import seed_all
from unittests._helpers.testers import Metric, MetricTester
from unittests.retrieval._inputs import _input_retrieval_scores as _irs
from unittests.retrieval._inputs import _input_retrieval_scores_2d as _irs_2d
from unittests.retrieval._inputs import _input_retrieval_scores_all_target as _irs_all
from unittests.retrieval._inputs import _input_retrieval_scores_empty as _irs_empty
from unittests.retrieval._inputs import _input_retrieval_scores_extra as _irs_extra
@@ -99,15 +100,20 @@ def _compute_sklearn_metric(
target: Union[Tensor, array],
indexes: Optional[np.ndarray] = None,
metric: Optional[Callable] = None,
empty_target_action: str = "skip",
empty_target_action: Optional[str] = None,
ignore_index: Optional[int] = None,
reverse: bool = False,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
metric_name: Optional[str] = None,
**kwargs: Any,
) -> Tensor:
"""Compute metric with multiple iterations over every query predictions set."""
if indexes is None:
indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
if metric_name == "ndcg" and preds.ndim == 2:
row_indexes = np.arange(preds.shape[0], dtype=np.int64)[:, None]
indexes = np.tile(row_indexes, (1, preds.shape[1]))
else:
indexes = np.zeros_like(preds, dtype=np.int64)
if isinstance(indexes, Tensor):
indexes = indexes.cpu().numpy()
if isinstance(preds, Tensor):
@@ -393,6 +399,7 @@ def _concat_tests(*tests: tuple[dict]) -> dict:
"argnames": "preds,target",
"argvalues": [
(_irs.preds, _irs.target),
(_irs_2d.preds, _irs_2d.target),
(_irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_int_tgt.preds, _irs_int_tgt.target),
@@ -494,11 +501,14 @@ def run_functional_metric_test(
metric_functional: Callable,
reference_metric: Callable,
metric_args: dict,
metric_name: Optional[str] = None,
reverse: bool = False,
**kwargs: Any,
):
"""Test functional implementation of metric."""
_ref_metric_adapted = partial(_compute_sklearn_metric, metric=reference_metric, reverse=reverse, **metric_args)
_ref_metric_adapted = partial(
_compute_sklearn_metric, metric=reference_metric, reverse=reverse, metric_name=metric_name, **metric_args
)

super().run_functional_metric_test(
preds=preds,
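To illustrate the reference-metric plumbing above: when `metric_name == "ndcg"` and the predictions are 2D, each row gets its own query id so the sklearn-based reference groups documents per row. A minimal numpy sketch of that index construction, with made-up shapes:

```python
import numpy as np

preds = np.random.rand(3, 2)  # 3 queries with 2 documents each

row_indexes = np.arange(preds.shape[0], dtype=np.int64)[:, None]
indexes = np.tile(row_indexes, (1, preds.shape[1]))
print(indexes)
# [[0 0]
#  [1 1]
#  [2 2]]
```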
9 changes: 6 additions & 3 deletions tests/unittests/retrieval/test_ndcg.py
@@ -120,17 +120,20 @@ def test_class_metric_ignore_index(
)

@pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target)
@pytest.mark.parametrize("empty_target_action", ["skip", "pos", "neg"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
def test_functional_metric(self, preds: Tensor, target: Tensor, empty_target_action: str, k: int):
"""Test functional implementation of metric."""
metric_args = {"empty_target_action": empty_target_action, "top_k": k}

target = target if target.min() >= 0 else target - target.min()
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=retrieval_normalized_dcg,
reference_metric=_ndcg_at_k,
metric_args={},
top_k=k,
metric_args=metric_args,
metric_name="ndcg",
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)