Skip to content

Commit b658bb9

Browse files
committed
Use new-style typing annotations, not comments
1 parent 6215000 commit b658bb9

File tree

6 files changed

+11
-15
lines changed

6 files changed

+11
-15
lines changed

pymc/backends/arviz.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,11 @@
2626
import pymc
2727

2828
from pymc.aesaraf import extract_obs_data
29-
from pymc.model import modelcontext
29+
from pymc.model import Model, modelcontext
3030
from pymc.util import get_default_varnames
3131

3232
if TYPE_CHECKING:
33-
3433
from pymc.backends.base import MultiTrace # pylint: disable=invalid-name
35-
from pymc.model import Model
3634

3735
___all__ = [""]
3836

@@ -144,12 +142,10 @@ def insert(self, k: str, v, idx: int):
144142
class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
145143
"""Encapsulate InferenceData specific logic."""
146144

147-
model = None # type: Optional[Model]
148-
nchains = None # type: int
149-
ndraws = None # type: int
150-
posterior_predictive = None # Type: Optional[Mapping[str, np.ndarray]]
151-
predictions = None # Type: Optional[Mapping[str, np.ndarray]]
152-
prior = None # Type: Optional[Mapping[str, np.ndarray]]
145+
model: Optional[Model] = None
146+
posterior_predictive: Optional[Mapping[str, np.ndarray]] = None
147+
predictions: Optional[Mapping[str, np.ndarray]] = None
148+
prior: Optional[Mapping[str, np.ndarray]] = None
153149

154150
def __init__(
155151
self,

pymc/backends/ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
8585
if self._stats is None:
8686
self._stats = []
8787
for sampler in sampler_vars:
88-
data = dict() # type: Dict[str, np.ndarray]
88+
data: Dict[str, np.ndarray] = dict()
8989
self._stats.append(data)
9090
for varname, dtype in sampler.items():
9191
data[varname] = np.zeros(draws, dtype=dtype)

pymc/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class Minibatch(TensorVariable):
301301
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
302302
"""
303303

304-
RNG = collections.defaultdict(list) # type: Dict[str, List[Any]]
304+
RNG: Dict[str, List[Any]] = collections.defaultdict(list)
305305

306306
@aesara.config.change_flags(compute_test_value="raise")
307307
def __init__(

pymc/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@
6363

6464
DIST_PARAMETER_TYPES: TypeAlias = Union[np.ndarray, int, float, TensorVariable]
6565

66-
vectorized_ppc = contextvars.ContextVar(
66+
vectorized_ppc: contextvars.ContextVar[Optional[Callable]] = contextvars.ContextVar(
6767
"vectorized_ppc", default=None
68-
) # type: contextvars.ContextVar[Optional[Callable]]
68+
)
6969

7070
PLATFORM = sys.platform
7171

pymc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
188188
on the stack, or ``None``. If ``error_if_none`` is True (default),
189189
raise a ``TypeError`` instead of returning ``None``."""
190190
try:
191-
candidate = cls.get_contexts()[-1] # type: Optional[T]
191+
candidate: Optional[T] = cls.get_contexts()[-1]
192192
except IndexError as e:
193193
# Calling code expects to get a TypeError if the entity
194194
# is unfound, and there's too much to fix.

pymc/variational/opvi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,7 +1482,7 @@ def sample(
14821482

14831483
if random_seed is not None:
14841484
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
1485-
samples = self.sample_dict_fn(draws, random_seed=random_seed) # type: dict
1485+
samples: dict = self.sample_dict_fn(draws, random_seed=random_seed)
14861486
points = ({name: records[i] for name, records in samples.items()} for i in range(draws))
14871487

14881488
trace = NDArray(

0 commit comments

Comments
 (0)