Skip to content

Commit 7a188b8

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 001d418 commit 7a188b8

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

pymc/model/validation.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323
import numpy as np
2424
import pytensor.tensor as pt
25+
2526
from pytensor.graph.basic import Variable
26-
from pytensor.tensor.variable import TensorVariable, TensorConstant
2727

2828
try:
2929
unused = TYPE_CHECKING
@@ -238,7 +238,9 @@ def check_shape_dims_match(model: Model) -> list[str]:
238238
except Exception:
239239
evaluated_shape.append(None) # Can't validate
240240
else:
241-
evaluated_shape.append(int(shape_elem) if shape_elem is not None else None)
241+
evaluated_shape.append(
242+
int(shape_elem) if shape_elem is not None else None
243+
)
242244
shape_idx += 1
243245

244246
# Compare only elements we could evaluate
@@ -256,8 +258,7 @@ def check_shape_dims_match(model: Model) -> list[str]:
256258
if mismatches:
257259
errors.append(
258260
f"Variable '{var_name}' declares dims {dims} but its shape "
259-
f"does not match the coordinate lengths:\n"
260-
+ "\n".join(mismatches)
261+
f"does not match the coordinate lengths:\n" + "\n".join(mismatches)
261262
)
262263
except Exception:
263264
# If we can't evaluate the shape, skip this check
@@ -324,7 +325,9 @@ def check_coord_lengths(model: Model) -> list[str]:
324325
if dims is not None and dim_name in dims:
325326
using_vars.append(var_name)
326327

327-
var_list = ", ".join([f"'{v}'" for v in sorted(using_vars)]) if using_vars else "variables"
328+
var_list = (
329+
", ".join([f"'{v}'" for v in sorted(using_vars)]) if using_vars else "variables"
330+
)
328331

329332
errors.append(
330333
f"Dimension '{dim_name}' has coordinate values of length {coord_length}, "
@@ -338,4 +341,3 @@ def check_coord_lengths(model: Model) -> list[str]:
338341
pass
339342

340343
return errors
341-

tests/model/test_dims_coords_validation.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
import pymc as pm
21+
2122
from pymc.model.validation import validate_dims_coords_consistency
2223

2324

@@ -42,7 +43,9 @@ def test_missing_coord_in_sample_raises(self):
4243
pm.Normal("x", 0, 1, dims=("time",))
4344

4445
with pytest.raises(ValueError, match="Dimension 'time'.*not defined in model.coords"):
45-
pm.sample(draws=10, tune=10, chains=1, progressbar=False, compute_convergence_checks=False)
46+
pm.sample(
47+
draws=10, tune=10, chains=1, progressbar=False, compute_convergence_checks=False
48+
)
4649

4750
def test_shape_mismatch_raises(self):
4851
"""Test that shape-dims mismatch raises clear error."""
@@ -66,7 +69,9 @@ def test_shape_mismatch_in_sample_raises(self):
6669
pm.Normal("x", 0, 1, shape=(5,), dims=("time",))
6770

6871
with pytest.raises(ValueError, match="Variable 'x'.*shape.*does not match"):
69-
pm.sample(draws=10, tune=10, chains=1, progressbar=False, compute_convergence_checks=False)
72+
pm.sample(
73+
draws=10, tune=10, chains=1, progressbar=False, compute_convergence_checks=False
74+
)
7075

7176
def test_coord_length_mismatch_raises(self):
7277
"""Test that coord length mismatch raises clear error."""
@@ -234,14 +239,15 @@ def test_complex_model_passes(self):
234239
alpha = pm.Normal("alpha", 0, 1, dims=("group",))
235240
beta = pm.Normal("beta", 0, 1, dims=("time", "location"))
236241
gamma = pm.Normal("gamma", 0, 1)
237-
242+
238243
# Deterministic with dims
239-
mu = pm.Deterministic("mu", alpha[:, None, None] + beta, dims=("group", "time", "location"))
240-
244+
mu = pm.Deterministic(
245+
"mu", alpha[:, None, None] + beta, dims=("group", "time", "location")
246+
)
247+
241248
# Observed data
242249
data = pm.Data("data", np.zeros((3, 10, 5)), dims=("group", "time", "location"))
243250
pm.Normal("y", mu=mu, sigma=1, observed=data, dims=("group", "time", "location"))
244251

245252
# Should pass validation
246253
validate_dims_coords_consistency(model)
247-

0 commit comments

Comments
 (0)