@@ -1067,24 +1067,24 @@ def add_coord(
10671067 raise ValueError (
10681068 f"Either `values` or `length` must be specified for the '{ name } ' dimension."
10691069 )
1070- if isinstance (length , int ):
1071- length = at .constant (length )
1072- elif length is not None and not isinstance (length , Variable ):
1073- raise ValueError (
1074- f"The `length` passed for the '{ name } ' coord must be an Aesara Variable or None."
1075- )
10761070 if values is not None :
10771071 # Conversion to a tuple ensures that the coordinate values are immutable.
10781072 # Also unlike numpy arrays the's tuple.index(...) which is handy to work with.
10791073 values = tuple (values )
10801074 if name in self .coords :
10811075 if not np .array_equal (values , self .coords [name ]):
10821076 raise ValueError (f"Duplicate and incompatible coordinate: { name } ." )
1077+ if length is not None and not isinstance (length , (int , Variable )):
1078+ raise ValueError (
1079+ f"The `length` passed for the '{ name } ' coord must be an int, Aesara Variable or None."
1080+ )
10831081 else :
1082+ if length is None :
1083+ length = len (values )
10841084 if mutable :
1085- self ._dim_lengths [name ] = length or aesara .shared (len ( values ) )
1085+ self ._dim_lengths [name ] = aesara .shared (length )
10861086 else :
1087- self ._dim_lengths [name ] = length or aesara .tensor .constant (len ( values ) )
1087+ self ._dim_lengths [name ] = aesara .tensor .constant (length )
10881088 self ._coords [name ] = values
10891089
10901090 def add_coords (
@@ -1101,6 +1101,36 @@ def add_coords(
11011101 for name , values in coords .items ():
11021102 self .add_coord (name , values , length = lengths .get (name , None ))
11031103
1104+ def set_dim (self , name : str , new_length : int , coord_values : Optional [Sequence ] = None ):
1105+ """Resizes a mutable dimension.
1106+
1107+ Parameters
1108+ ----------
1109+ name
1110+ Name of the dimension.
1111+ new_length
1112+ New length of the dimension.
1113+ coord_values
1114+ Optional sequence of coordinate values.
1115+ """
1116+ if not isinstance (self .dim_lengths [name ], ScalarSharedVariable ):
1117+ raise ValueError (f"The dimension '{ name } ' is immutable." )
1118+ if self .coords .get (name , None ) is not None and coord_values is None :
1119+ raise ValueError (
1120+ f"'{ name } ' has coord values. Pass `set_dim(..., coord_values=...)` to update them."
1121+ )
1122+ if coord_values is not None :
1123+ len_cvals = len (coord_values )
1124+ if len_cvals != new_length :
1125+ raise ShapeError (
1126+ f"Length of new coordinate values does not match the new dimension length." ,
1127+ actual = len_cvals ,
1128+ expected = new_length ,
1129+ )
1130+ self ._coords [name ] = tuple (coord_values )
1131+ self .dim_lengths [name ].set_value (new_length )
1132+ return
1133+
11041134 def set_data (
11051135 self ,
11061136 name : str ,
0 commit comments