Skip to content

Conversation

@atheendre130505
Copy link

@atheendre130505 atheendre130505 commented Nov 4, 2025

Description

This PR replaces the deprecated pt.batched_dot function with the preferred pt.sum operation in the KroneckerNormal distribution's logprob calculation, addressing issue #7878.

Problem

The current implementation uses batched_dot, which is deprecated in PyTensor and triggers warnings. Deprecated functions may lead to future breakage and lower performance.

Solution

Refactored the KroneckerNormal logprob code to use pt.sum with explicit axis parameters, achieving the same functionality without relying on deprecated APIs.

Tests

  • All existing tests related to KroneckerNormal pass successfully.
  • No deprecation warnings are shown when running the updated code.
  • Validated numerical equivalence with the previous implementation.

Related Issue

Fixes #7878

Checklist

  • Follows PyMC contributing guidelines
  • All tests pass locally
  • No API-breaking changes

📚 Documentation preview 📚: https://pymc--7951.org.readthedocs.build/en/7951/

- Fixes Issue pymc-devs#7878
- Replace pt.batched_dot(sqrt_quad.T, sqrt_quad.T) with pt.sum(sqrt_quad.T ** 2, axis=-1)
- Computes squared norm per sample using modern PyTensor operations
- Eliminates deprecation warnings and ensures future compatibility
@github-actions github-actions bot added GP Gaussian Process pytensor labels Nov 4, 2025
@atheendre130505
Copy link
Author

pre-commit.ci autofix

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

GP Gaussian Process pytensor

Projects

None yet

Development

Successfully merging this pull request may close these issues.

KroneckerNormal using deprecated batched_dot

1 participant