Skip to content

Commit 5f9a2e3

Browse files
authored
Pass kwargs from pm.Data to aesara.shared (#5098)
* Tests kwargs are passed from Data to aesara.shared * Pass kwargs from Data to aesara.shared * Document Data kwargs in release notes * Convert Aesara references to intersphinx
1 parent 02b8675 commit 5f9a2e3

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
8888
- New features for BART:
8989
- Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
9090
- Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
91+
- `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
9192
- ...
9293

9394

pymc/data.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ def align_minibatches(batches=None):
464464

465465

466466
class Data:
467-
"""Data container class that wraps the Aesara ``SharedVariable`` class
468-
and lets the model be aware of its inputs and outputs.
467+
"""Data container class that wraps :func:`aesara.shared` and lets
468+
the model be aware of its inputs and outputs.
469469
470470
Parameters
471471
----------
@@ -478,10 +478,12 @@ class Data:
478478
random variables). Use this when `value` is a pandas Series or DataFrame. The
479479
`dims` will then be the name of the Series / DataFrame's columns. See ArviZ
480480
documentation for more information about dimensions and coordinates:
481-
https://arviz-devs.github.io/arviz/notebooks/Introduction.html
481+
:ref:`arviz:quickstart`.
482482
export_index_as_coords: bool, optional, default=False
483483
If True, the `Data` container will try to infer what the coordinates should be
484484
if there is an index in `value`.
485+
**kwargs: dict, optional
486+
Extra arguments passed to :func:`aesara.shared`.
485487
486488
Examples
487489
--------
@@ -512,7 +514,15 @@ class Data:
512514
https://docs.pymc.io/notebooks/data_container.html
513515
"""
514516

515-
def __new__(self, name, value, *, dims=None, export_index_as_coords=False):
517+
def __new__(
518+
self,
519+
name,
520+
value,
521+
*,
522+
dims=None,
523+
export_index_as_coords=False,
524+
**kwargs,
525+
):
516526
if isinstance(value, list):
517527
value = np.array(value)
518528

@@ -528,7 +538,7 @@ def __new__(self, name, value, *, dims=None, export_index_as_coords=False):
528538

529539
# `pandas_to_array` takes care of parameter `value` and
530540
# transforms it to something digestible for pymc
531-
shared_object = aesara.shared(pandas_to_array(value), name)
541+
shared_object = aesara.shared(pandas_to_array(value), name, **kwargs)
532542

533543
if isinstance(dims, str):
534544
dims = (dims,)

pymc/tests/test_data_container.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,19 @@ def test_implicit_coords_dataframe(self):
366366
assert "columns" in pmodel.coords
367367
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
368368

369+
def test_data_kwargs(self):
370+
strict_value = True
371+
allow_downcast_value = False
372+
with pm.Model():
373+
data = pm.Data(
374+
"data",
375+
value=[[1.0], [2.0], [3.0]],
376+
strict=strict_value,
377+
allow_downcast=allow_downcast_value,
378+
)
379+
assert data.container.strict is strict_value
380+
assert data.container.allow_downcast is allow_downcast_value
381+
369382

370383
def test_data_naming():
371384
"""

0 commit comments

Comments
 (0)