60 | 60 | from pymc.distributions import joint_logpt |
61 | 61 | from pymc.distributions.logprob import _get_scaling |
62 | 62 | from pymc.distributions.transforms import _default_transform |
63 | | -from pymc.exceptions import ImputationWarning, SamplingError, ShapeError |
| 63 | +from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning |
64 | 64 | from pymc.initial_point import make_initial_point_fn |
65 | 65 | from pymc.math import flatten_list |
66 | 66 | from pymc.util import ( |
@@ -1193,14 +1193,31 @@ def set_data( |
1193 | 1193 | f"{new_length}, so new coord values for the {dname} dimension are required." |
1194 | 1194 | ) |
1195 | 1195 | if isinstance(length_tensor, TensorConstant): |
 | 1196 | + # The dimension was fixed in length.
 | 1197 | + # Resizing a data variable along this dimension would
 | 1198 | + # definitely lead to shape problems.
1196 | 1199 | raise ShapeError( |
1197 | 1200 | f"Resizing dimension '{dname}' is impossible, because " |
1198 | 1201 | "a 'TensorConstant' stores its length. To be able " |
1199 | 1202 | "to change the dimension length, pass `mutable=True` when " |
1200 | 1203 | "registering the dimension via `model.add_coord`, " |
1201 | 1204 | "or define it via a `pm.MutableData` variable." |
1202 | 1205 | ) |
| 1206 | + elif isinstance(length_tensor, ScalarSharedVariable): |
 | 1207 | + # The dimension is mutable, but its length is not linked to
 | 1208 | + # the shape of a data variable. This is allowed, but slightly dangerous.
| 1209 | + warnings.warn( |
| 1210 | + f"You are resizing a variable with dimension '{dname}' which was initialized" |
| 1211 | + " as a mutable dimension and is not linked to the `MutableData` variable." |
| 1212 | + " Remember to update the dimension length by calling " |
| 1213 | + f"`Model.set_dim({dname}, new_length={new_length})` manually," |
| 1214 | + " preferably _before_ updating `MutableData` variables that use this dimension.", |
| 1215 | + ShapeWarning, |
| 1216 | + stacklevel=2, |
| 1217 | + ) |
1203 | 1218 | else: |
 | 1219 | + # The dimension was created from the shape of another model variable.
 | 1220 | + # If that variable is not shared (i.e. not mutable), resizing will definitely cause shape problems.
1204 | 1221 | length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0] |
1205 | 1222 | if not isinstance(length_belongs_to, SharedVariable): |
1206 | 1223 | raise ShapeError( |
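
To illustrate the two branches this hunk distinguishes, here is a minimal sketch against the PyMC v4 API on this branch. It assumes `add_coord` accepts `length=`/`mutable=` keyword arguments (the `mutable=True` hint comes from the error message itself); the dim and variable names `city`, `obs`, `x`, `y` are purely illustrative:

```python
import pymc as pm
from pymc.exceptions import ShapeError

# Branch 1: an immutable dimension stores its length as a TensorConstant,
# so any resize attempt raises ShapeError.
with pm.Model() as m1:
    m1.add_coord("city", length=3, mutable=False)
    pm.MutableData("x", [1.0, 2.0, 3.0], dims="city")

with m1:
    try:
        m1.set_data("x", [1.0, 2.0, 3.0, 4.0])  # length 3 -> 4
    except ShapeError as e:
        print(e)  # "Resizing dimension 'city' is impossible, ..."

# Branch 2: a mutable dimension registered via add_coord stores its length
# in a standalone ScalarSharedVariable, not in a MutableData variable's shape.
with pm.Model() as m2:
    m2.add_coord("obs", length=3, mutable=True)
    pm.MutableData("y", [1.0, 2.0, 3.0], dims="obs")

with m2:
    # Updating the dim length first, as the warning recommends, keeps the
    # model consistent; calling set_data first would emit a ShapeWarning.
    m2.set_dim("obs", new_length=4)
    m2.set_data("y", [1.0, 2.0, 3.0, 4.0])
```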
|