Skip to content

Commit 270caa3

Browse files
committed
More strict in as_xtensor
1 parent 35b1fef commit 270caa3

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/xtensor/type.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,12 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
751751

752752
if isinstance(x, Variable):
753753
if isinstance(x.type, XTensorType):
754-
return x
754+
if (dims is None) or (x.type.dims == dims):
755+
return x
756+
else:
757+
raise ValueError(
758+
f"Variable {x} has dims {x.type.dims}, but requested dims are {dims}."
759+
)
755760
if isinstance(x.type, TensorType):
756761
if dims is None:
757762
if x.type.ndim == 0:

0 commit comments

Comments
 (0)