Skip to content

Commit 51d1ff0

Browse files
committed
fixed laod function
1 parent 6ebce2f commit 51d1ff0

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

pymc_experimental/model_builder.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def save(self, fname):
158158
self.idata.to_netcdf(file)
159159

160160
@classmethod
161-
def load(cls, self, fname):
161+
def load(cls, fname):
162162
"""
163163
Loads inference data for the model.
164164
@@ -185,16 +185,13 @@ def load(cls, self, fname):
185185
"""
186186

187187
filepath = Path(str(fname))
188-
data = az.from_netcdf(filepath)
189-
idata = data
190-
if idata.attrs is not None:
191-
if self.id() == idata.attrs["id"]:
192-
self = cls(idata.attrs["sample_config"], idata.attrs["model_config"])
193-
self.idata = idata
194-
else:
195-
raise ValueError(
196-
f"The route '{file}' does not contain an inference data of the same model '{self.__name__}'"
197-
)
188+
idata = az.from_netcdf(filepath)
189+
self = cls(dict(zip(idata.attrs['model_config_keys'],idata.attrs['model_config_values'])), dict(zip(idata.attrs['sample_config_keys'],idata.attrs['sample_config_values'])), idata.data)
190+
self.idata=idata
191+
if self.id() != idata.attrs["id"]:
192+
raise ValueError(
193+
f"The route '{fname}' does not contain an inference data of the same model '{self._model_type}'"
194+
)
198195
return self
199196

200197
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
@@ -235,8 +232,11 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
235232
self.idata.attrs["id"] = self.id()
236233
self.idata.attrs["model_type"] = self._model_type
237234
self.idata.attrs["version"] = self.version
238-
self.idata.attrs["sample_config"] = tuple(self.sample_config)
239-
self.idata.attrs["model_config"] = tuple(self.model_config)
235+
self.idata.attrs["sample_config_keys"] = tuple(self.sample_config.keys())
236+
self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values())
237+
self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys())
238+
self.idata.attrs["model_config_values"] = tuple(self.model_config.values())
239+
self.idata.add_groups(data = self.data.to_xarray())
240240
return self.idata
241241

242242
def predict(
@@ -352,5 +352,5 @@ def id(self):
352352
hasher.update(str(self.model_config.values()).encode())
353353
hasher.update(self.version.encode())
354354
hasher.update(self._model_type.encode())
355-
hasher.update(str(self.sample_config.values()).encode())
355+
# hasher.update(str(self.sample_config.values()).encode())
356356
return hasher.hexdigest()[:16]

0 commit comments

Comments
 (0)