@@ -158,7 +158,7 @@ def save(self, fname):
158
158
self .idata .to_netcdf (file )
159
159
160
160
@classmethod
161
- def load (cls , self , fname ):
161
+ def load (cls , fname ):
162
162
"""
163
163
Loads inference data for the model.
164
164
@@ -185,16 +185,13 @@ def load(cls, self, fname):
185
185
"""
186
186
187
187
filepath = Path (str (fname ))
188
- data = az .from_netcdf (filepath )
189
- idata = data
190
- if idata .attrs is not None :
191
- if self .id () == idata .attrs ["id" ]:
192
- self = cls (idata .attrs ["sample_config" ], idata .attrs ["model_config" ])
193
- self .idata = idata
194
- else :
195
- raise ValueError (
196
- f"The route '{ file } ' does not contain an inference data of the same model '{ self .__name__ } '"
197
- )
188
+ idata = az .from_netcdf (filepath )
189
+ self = cls (dict (zip (idata .attrs ['model_config_keys' ],idata .attrs ['model_config_values' ])), dict (zip (idata .attrs ['sample_config_keys' ],idata .attrs ['sample_config_values' ])), idata .data )
190
+ self .idata = idata
191
+ if self .id () != idata .attrs ["id" ]:
192
+ raise ValueError (
193
+ f"The route '{ fname } ' does not contain an inference data of the same model '{ self ._model_type } '"
194
+ )
198
195
return self
199
196
200
197
def fit (self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ):
@@ -235,8 +232,11 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
235
232
self .idata .attrs ["id" ] = self .id ()
236
233
self .idata .attrs ["model_type" ] = self ._model_type
237
234
self .idata .attrs ["version" ] = self .version
238
- self .idata .attrs ["sample_config" ] = tuple (self .sample_config )
239
- self .idata .attrs ["model_config" ] = tuple (self .model_config )
235
+ self .idata .attrs ["sample_config_keys" ] = tuple (self .sample_config .keys ())
236
+ self .idata .attrs ["sample_config_values" ] = tuple (self .sample_config .values ())
237
+ self .idata .attrs ["model_config_keys" ] = tuple (self .model_config .keys ())
238
+ self .idata .attrs ["model_config_values" ] = tuple (self .model_config .values ())
239
+ self .idata .add_groups (data = self .data .to_xarray ())
240
240
return self .idata
241
241
242
242
def predict (
@@ -352,5 +352,5 @@ def id(self):
352
352
hasher .update (str (self .model_config .values ()).encode ())
353
353
hasher .update (self .version .encode ())
354
354
hasher .update (self ._model_type .encode ())
355
- hasher .update (str (self .sample_config .values ()).encode ())
355
+ # hasher.update(str(self.sample_config.values()).encode())
356
356
return hasher .hexdigest ()[:16 ]
0 commit comments