From 9e2055ed5bcf52ec3ed12c0266528912f2b7c0c0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 09:56:21 +0000 Subject: [PATCH 1/4] Optionally return state in from_config. Useful for Flax schedulers. --- src/diffusers/configuration_utils.py | 11 +++++++++-- src/diffusers/schedulers/scheduling_ddim_flax.py | 1 + src/diffusers/schedulers/scheduling_pndm_flax.py | 3 ++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 2ab85ecee16d..8b72b6eaa16f 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -157,12 +157,19 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) + return_tuple = (model,) + + # Some components (Flax schedulers) have a state. + if getattr(cls, "has_state", False): # Check for "create_state" in model instead? + state = model.create_state() + return_tuple += (state,) if return_unused_kwargs: - return model, unused_kwargs + return return_tuple + (unused_kwargs,) else: - return model + return return_tuple if len(return_tuple) > 1 else model @classmethod def get_config_dict( diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 30c873b45e59..21a070042d72 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -104,6 +104,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ + has_state = True @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 4c8c43810b6f..62375cec7eeb 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -112,7 +112,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ - + has_state = True + @register_to_config def __init__( self, From f674320ca2b2f8de808f0a58f60b5b1ace84783e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 06:18:20 +0000 Subject: [PATCH 2/4] has_state is now a property, make check more strict. I don't check the class is `SchedulerMixin` to prevent circular dependencies. It should be enough that the class name starts with "Flax" the object declares it "has_state" and the "create_state" exists too. --- src/diffusers/configuration_utils.py | 4 ++-- src/diffusers/schedulers/scheduling_ddim_flax.py | 4 +++- src/diffusers/schedulers/scheduling_pndm_flax.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 8b72b6eaa16f..d607c5ba8a07 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -161,8 +161,8 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret model = cls(**init_dict) return_tuple = (model,) - # Some components (Flax schedulers) have a state. - if getattr(cls, "has_state", False): # Check for "create_state" in model instead? + # Flax schedulers have a state, so return it. + if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): state = model.create_state() return_tuple += (state,) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 21a070042d72..5e218b7a8e70 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -104,7 +104,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ - has_state = True + @property + def has_state(self): + return True @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 62375cec7eeb..cac516a058fe 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -112,7 +112,9 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ - has_state = True + @property + def has_state(self): + return True @register_to_config def __init__( From 5ab5c4e7745562d49f14506b128ba20c57348bf5 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 06:21:52 +0000 Subject: [PATCH 3/4] Use state in pipeline from_pretrained. --- src/diffusers/pipeline_flax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index e65d95a37f3d..888706064963 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -437,8 +437,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, SchedulerMixin): - loaded_sub_model = load_method(loadable_folder) - params[name] = loaded_sub_model.create_state() + loaded_sub_model, scheduler_state = load_method(loadable_folder) + params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) From b17f51357e2480540881a92cd41096f86ff212b1 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 06:46:46 +0000 Subject: [PATCH 4/4] Make style --- src/diffusers/schedulers/scheduling_ddim_flax.py | 1 + src/diffusers/schedulers/scheduling_pndm_flax.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 5e218b7a8e70..d81d66607147 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -104,6 +104,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index cac516a058fe..8344505620c4 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -112,10 +112,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ + @property def has_state(self): return True - + @register_to_config def __init__( self,