diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1c5c3d7afd58..7c9e4e46a559 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -160,12 +160,19 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret if "dtype" in unused_kwargs: init_dict["dtype"] = unused_kwargs.pop("dtype") + # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) + return_tuple = (model,) + + # Flax schedulers have a state, so return it. + if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): + state = model.create_state() + return_tuple += (state,) if return_unused_kwargs: - return model, unused_kwargs + return return_tuple + (unused_kwargs,) else: - return model + return return_tuple if len(return_tuple) > 1 else model @classmethod def get_config_dict( diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index b7de33d2d532..6cfd7ae32112 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -437,8 +437,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, SchedulerMixin): - loaded_sub_model = load_method(loadable_folder) - params[name] = loaded_sub_model.create_state() + loaded_sub_model, scheduler_state = load_method(loadable_folder) + params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 30c873b45e59..d81d66607147 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 4c8c43810b6f..8344505620c4 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self,