1818import pytest
1919
2020import pymc as pm
21+
2122from pymc .model .validation import validate_dims_coords_consistency
2223
2324
@@ -42,7 +43,9 @@ def test_missing_coord_in_sample_raises(self):
4243 pm .Normal ("x" , 0 , 1 , dims = ("time" ,))
4344
4445 with pytest .raises (ValueError , match = "Dimension 'time'.*not defined in model.coords" ):
45- pm .sample (draws = 10 , tune = 10 , chains = 1 , progressbar = False , compute_convergence_checks = False )
46+ pm .sample (
47+ draws = 10 , tune = 10 , chains = 1 , progressbar = False , compute_convergence_checks = False
48+ )
4649
4750 def test_shape_mismatch_raises (self ):
4851 """Test that shape-dims mismatch raises clear error."""
@@ -66,7 +69,9 @@ def test_shape_mismatch_in_sample_raises(self):
6669 pm .Normal ("x" , 0 , 1 , shape = (5 ,), dims = ("time" ,))
6770
6871 with pytest .raises (ValueError , match = "Variable 'x'.*shape.*does not match" ):
69- pm .sample (draws = 10 , tune = 10 , chains = 1 , progressbar = False , compute_convergence_checks = False )
72+ pm .sample (
73+ draws = 10 , tune = 10 , chains = 1 , progressbar = False , compute_convergence_checks = False
74+ )
7075
7176 def test_coord_length_mismatch_raises (self ):
7277 """Test that coord length mismatch raises clear error."""
@@ -234,14 +239,15 @@ def test_complex_model_passes(self):
234239 alpha = pm .Normal ("alpha" , 0 , 1 , dims = ("group" ,))
235240 beta = pm .Normal ("beta" , 0 , 1 , dims = ("time" , "location" ))
236241 gamma = pm .Normal ("gamma" , 0 , 1 )
237-
242+
238243 # Deterministic with dims
239- mu = pm .Deterministic ("mu" , alpha [:, None , None ] + beta , dims = ("group" , "time" , "location" ))
240-
244+ mu = pm .Deterministic (
245+ "mu" , alpha [:, None , None ] + beta , dims = ("group" , "time" , "location" )
246+ )
247+
241248 # Observed data
242249 data = pm .Data ("data" , np .zeros ((3 , 10 , 5 )), dims = ("group" , "time" , "location" ))
243250 pm .Normal ("y" , mu = mu , sigma = 1 , observed = data , dims = ("group" , "time" , "location" ))
244251
245252 # Should pass validation
246253 validate_dims_coords_consistency (model )
247-
0 commit comments