@@ -1067,25 +1067,26 @@ def add_coord(
10671067 raise ValueError (
10681068 f"Either `values` or `length` must be specified for the '{ name } ' dimension."
10691069 )
1070- if isinstance (length , int ):
1071- length = at .constant (length )
1072- elif length is not None and not isinstance (length , Variable ):
1073- raise ValueError (
1074- f"The `length` passed for the '{ name } ' coord must be an Aesara Variable or None."
1075- )
10761070 if values is not None :
10771071 # Conversion to a tuple ensures that the coordinate values are immutable.
10781072 # Also unlike numpy arrays the's tuple.index(...) which is handy to work with.
10791073 values = tuple (values )
10801074 if name in self .coords :
10811075 if not np .array_equal (values , self .coords [name ]):
10821076 raise ValueError (f"Duplicate and incompatible coordinate: { name } ." )
1083- else :
1077+ if length is not None and not isinstance (length , (int , Variable )):
1078+ raise ValueError (
1079+ f"The `length` passed for the '{ name } ' coord must be an int, Aesara Variable or None."
1080+ )
1081+ if length is None :
1082+ length = len (values )
1083+ if not isinstance (length , Variable ):
10841084 if mutable :
1085- self . _dim_lengths [ name ] = length or aesara .shared (len ( values ) )
1085+ length = aesara .shared (length )
10861086 else :
1087- self ._dim_lengths [name ] = length or aesara .tensor .constant (len (values ))
1088- self ._coords [name ] = values
1087+ length = aesara .tensor .constant (length )
1088+ self ._dim_lengths [name ] = length
1089+ self ._coords [name ] = values
10891090
10901091 def add_coords (
10911092 self ,
@@ -1101,6 +1102,36 @@ def add_coords(
11011102 for name , values in coords .items ():
11021103 self .add_coord (name , values , length = lengths .get (name , None ))
11031104
1105+ def set_dim (self , name : str , new_length : int , coord_values : Optional [Sequence ] = None ):
1106+ """Resizes a mutable dimension.
1107+
1108+ Parameters
1109+ ----------
1110+ name
1111+ Name of the dimension.
1112+ new_length
1113+ New length of the dimension.
1114+ coord_values
1115+ Optional sequence of coordinate values.
1116+ """
1117+ if not isinstance (self .dim_lengths [name ], ScalarSharedVariable ):
1118+ raise ValueError (f"The dimension '{ name } ' is immutable." )
1119+ if self .coords .get (name , None ) is not None and coord_values is None :
1120+ raise ValueError (
1121+ f"'{ name } ' has coord values. Pass `set_dim(..., coord_values=...)` to update them."
1122+ )
1123+ if coord_values is not None :
1124+ len_cvals = len (coord_values )
1125+ if len_cvals != new_length :
1126+ raise ShapeError (
1127+ f"Length of new coordinate values does not match the new dimension length." ,
1128+ actual = len_cvals ,
1129+ expected = new_length ,
1130+ )
1131+ self ._coords [name ] = tuple (coord_values )
1132+ self .dim_lengths [name ].set_value (new_length )
1133+ return
1134+
11041135 def set_data (
11051136 self ,
11061137 name : str ,
0 commit comments