
Commit 688031c

Fix import with Flax but without PyTorch (#688)
* Don't use `load_state_dict` if torch is not installed.
* Define `SchedulerOutput` to use torch or flax arrays.
* Don't import LMSDiscreteScheduler without torch.
* Create distinct FlaxSchedulerOutput.
* Additional changes required for FlaxSchedulerMixin.
* Do not import torch pipelines in Flax.
* Revert "Define `SchedulerOutput` to use torch or flax arrays." This reverts commit f653140.
* Prefix Flax scheduler outputs for consistency.
* make style
* FlaxSchedulerOutput is now a dataclass.
* Don't use f-string without placeholders.
* Add blank line.
* Style (docstrings)
1 parent 7d0ba59 commit 688031c

13 files changed: +131 additions, -63 deletions
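The fix applies the backend-availability pattern used throughout the library: torch-backed modules are imported only when `is_torch_available()` reports that PyTorch can be imported, and placeholder objects are exposed otherwise. A minimal sketch of how such a check can work (simplified; the real helper in `diffusers.utils` is evaluated once at import time and also honors environment flags):

    import importlib.util


    def is_torch_available() -> bool:
        # True when a "torch" module can be found on the import path.
        return importlib.util.find_spec("torch") is not None


    if is_torch_available():
        import torch  # torch-backed code paths are safe to use
    else:
        torch = None  # Flax-only environments skip every torch import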

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@
         FlaxKarrasVeScheduler,
         FlaxLMSDiscreteScheduler,
         FlaxPNDMScheduler,
+        FlaxSchedulerMixin,
         FlaxScoreSdeVeScheduler,
     )
 else:
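In practice this means a Flax-only environment can now import the library and its Flax components without PyTorch present. For example (assuming flax and jax are installed and torch is not):

    # Without torch installed, these imports now succeed instead of
    # failing while diffusers itself is being imported.
    import diffusers
    from diffusers import FlaxDDIMScheduler, FlaxSchedulerMixin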

src/diffusers/modeling_flax_utils.py

Lines changed: 9 additions & 1 deletion
@@ -27,8 +27,8 @@
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError

+from . import is_torch_available
 from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
-from .modeling_utils import load_state_dict
 from .utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
@@ -391,6 +391,14 @@ def from_pretrained(
         )

         if from_pt:
+            if is_torch_available():
+                from .modeling_utils import load_state_dict
+            else:
+                raise EnvironmentError(
+                    "Can't load the model in PyTorch format because PyTorch is not installed. "
+                    "Please, install PyTorch or use native Flax weights."
+                )
+
             # Step 1: Get the pytorch file
             pytorch_model_file = load_state_dict(model_file)

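With this guard, a missing PyTorch installation surfaces as a clear `EnvironmentError` only when PyTorch weights are actually requested, not at import time. A rough sketch of both paths (checkpoint names are made up, and `FlaxUNet2DConditionModel` stands in for any `FlaxModelMixin` subclass):

    from diffusers import FlaxUNet2DConditionModel

    # Native Flax weights load fine without torch.
    model, params = FlaxUNet2DConditionModel.from_pretrained("user/flax-checkpoint")

    # Requesting PyTorch weights without torch installed now raises the
    # EnvironmentError above, rather than an ImportError at `import diffusers` time.
    model, params = FlaxUNet2DConditionModel.from_pretrained("user/pt-checkpoint", from_pt=True)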

src/diffusers/pipeline_flax_utils.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@

 from .configuration_utils import ConfigMixin
 from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
-from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging


@@ -46,7 +46,7 @@
 LOADABLE_CLASSES = {
     "diffusers": {
         "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
-        "SchedulerMixin": ["save_config", "from_config"],
+        "FlaxSchedulerMixin": ["save_config", "from_config"],
         "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
     },
     "transformers": {
@@ -436,7 +436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 else:
                     loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                     params[name] = loaded_params
-            elif issubclass(class_obj, SchedulerMixin):
+            elif issubclass(class_obj, FlaxSchedulerMixin):
                 loaded_sub_model, scheduler_state = load_method(loadable_folder)
                 params[name] = scheduler_state
             else:
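`LOADABLE_CLASSES` tells `FlaxDiffusionPipeline.from_pretrained` which save/load method pair to call for each component base class, so the scheduler entry has to name the Flax mixin for the `issubclass` dispatch above to match Flax schedulers, whose `from_config` also returns a state. A condensed sketch of that dispatch (`load_component` is a made-up helper name; the real loop does more bookkeeping):

    from diffusers.modeling_flax_utils import FlaxModelMixin
    from diffusers.schedulers.scheduling_utils_flax import FlaxSchedulerMixin


    def load_component(class_obj, load_method, folder, params, name):
        if issubclass(class_obj, FlaxModelMixin):
            sub_model, params[name] = load_method(folder, _do_init=False)
        elif issubclass(class_obj, FlaxSchedulerMixin):
            # Flax schedulers return (scheduler, state); the state joins the params.
            sub_model, params[name] = load_method(folder)
        else:
            sub_model = load_method(folder)
        return sub_model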

src/diffusers/pipelines/__init__.py

Lines changed: 10 additions & 6 deletions
@@ -1,12 +1,16 @@
 from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
-from .ddim import DDIMPipeline
-from .ddpm import DDPMPipeline
-from .latent_diffusion_uncond import LDMPipeline
-from .pndm import PNDMPipeline
-from .score_sde_ve import ScoreSdeVePipeline
-from .stochastic_karras_ve import KarrasVePipeline


+if is_torch_available():
+    from .ddim import DDIMPipeline
+    from .ddpm import DDPMPipeline
+    from .latent_diffusion_uncond import LDMPipeline
+    from .pndm import PNDMPipeline
+    from .score_sde_ve import ScoreSdeVePipeline
+    from .stochastic_karras_ve import KarrasVePipeline
+else:
+    from ..utils.dummy_pt_objects import *  # noqa F403
+
 if is_torch_available() and is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
     from .stable_diffusion import (
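The `else` branch star-imports placeholder classes so that `from diffusers import DDIMPipeline` still resolves without torch and fails with a helpful message only on use. The real `dummy_pt_objects` module is auto-generated and, if it follows the usual Hugging Face pattern, uses a `DummyObject` metaclass with `requires_backends`; a hand-written equivalent for one class would look roughly like:

    # Simplified stand-in for one auto-generated dummy object.
    class DDIMPipeline:
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "DDIMPipeline requires the PyTorch backend. "
                "Install torch, or use a Flax pipeline instead."
            )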

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import PIL
 from PIL import Image

-from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
+from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available


 @dataclass
@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
     nsfw_content_detected: List[bool]


-if is_transformers_available():
+if is_transformers_available() and is_torch_available():
     from .pipeline_stable_diffusion import StableDiffusionPipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline

src/diffusers/schedulers/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -34,10 +34,12 @@
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+    from .scheduling_utils_flax import FlaxSchedulerMixin
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403

-if is_scipy_available():
+
+if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
 else:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 8 additions & 8 deletions
@@ -23,7 +23,7 @@
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -68,11 +68,11 @@ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):


 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
     state: DDIMSchedulerState


-class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -183,7 +183,7 @@ def step(
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -197,11 +197,11 @@ def step(
             key (`random.KeyArray`): a PRNG key.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         if state.num_inference_steps is None:
@@ -252,7 +252,7 @@ def step(
         if not return_dict:
             return (prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
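After the rename, each Flax scheduler returns its own prefixed output type (`FlaxDDIMSchedulerOutput`, `FlaxDDPMSchedulerOutput`, `FlaxLMSSchedulerOutput`) instead of a per-module `FlaxSchedulerOutput` that shadowed the shared base class. Call sites only see the new type name; a rough sketch, assuming `scheduler`, `state`, `model_output`, `t`, and `sample` are already set up and that `step` takes them in this order:

    # With return_dict=True (the default) the result is a FlaxDDIMSchedulerOutput.
    out = scheduler.step(state, model_output, t, sample)
    prev_sample, state = out.prev_sample, out.state

    # With return_dict=False the result is a plain tuple; the sample comes first.
    prev_sample, state = scheduler.step(state, model_output, t, sample, return_dict=False)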

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 8 additions & 8 deletions
@@ -23,7 +23,7 @@
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -67,11 +67,11 @@ def create(cls, num_train_timesteps: int):


 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
     state: DDPMSchedulerState


-class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
     Langevin dynamics sampling.
@@ -191,7 +191,7 @@ def step(
         key: random.KeyArray,
         predict_epsilon: bool = True,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -205,11 +205,11 @@ def step(
             key (`random.KeyArray`): a PRNG key.
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         t = timestep
@@ -257,7 +257,7 @@ def step(
         if not return_dict:
             return (pred_prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+        return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

     def add_noise(
         self,

src/diffusers/schedulers/scheduling_karras_ve_flax.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils_flax import FlaxSchedulerMixin


 @flax.struct.dataclass
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
     state: KarrasVeSchedulerState


-class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
     the VE column of Table 1 from [1] for reference.
@@ -172,7 +172,7 @@ def step(
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
             [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
@@ -211,7 +211,7 @@ def step_correct(
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
             sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
             derivative (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
             prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 8 additions & 8 deletions
@@ -20,7 +20,7 @@
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 @flax.struct.dataclass
@@ -37,11 +37,11 @@ def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):


 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
     state: LMSDiscreteSchedulerState


-class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
     Katherine Crowson:
@@ -147,7 +147,7 @@ def step(
         sample: jnp.ndarray,
         order: int = 4,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -159,11 +159,11 @@ def step(
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         sigma = state.sigmas[timestep]
@@ -189,7 +189,7 @@ def step(
         if not return_dict:
             return (prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
