Commit c6357ed
[Flax] Stateless schedulers, fixes and refactors
1 parent: 784beee

13 files changed: +652 −550 lines

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 2 additions & 1 deletion

@@ -477,6 +477,7 @@ def collate_fn(examples):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()

     # Initialize our training
     train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -513,7 +514,7 @@ def compute_loss(params):

             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

             # Get the text embedding for conditioning
             if args.train_text_encoder:
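All three training scripts get the same two-line change (the text-to-image and textual-inversion diffs below repeat it): schedule tensors no longer live on the scheduler instance but in an explicit state object that can be threaded through jit/pmap as a pytree. A minimal sketch of the resulting call sequence; the shapes and RNG handling here are illustrative, not taken from the script:

import jax
import jax.numpy as jnp
from diffusers import FlaxDDPMScheduler

noise_scheduler = FlaxDDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
# The schedule tensors (betas, alphas, cumulative products) are built once here.
noise_scheduler_state = noise_scheduler.create_state()

rng = jax.random.PRNGKey(0)
latents = jnp.zeros((4, 4, 64, 64))  # illustrative latent batch
noise = jax.random.normal(rng, latents.shape)
timesteps = jax.random.randint(rng, (latents.shape[0],), 0, noise_scheduler.config.num_train_timesteps)

# The state is now passed as the first argument to add_noise.
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)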

examples/text_to_image/train_text_to_image_flax.py

Lines changed: 2 additions & 1 deletion

@@ -417,6 +417,7 @@ def collate_fn(examples):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()

     # Initialize our training
     rng = jax.random.PRNGKey(args.seed)
@@ -449,7 +450,7 @@ def compute_loss(params):

             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

             # Get the text embedding for conditioning
             encoder_hidden_states = text_encoder(

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 2 additions & 1 deletion

@@ -505,6 +505,7 @@ def update_fn(updates, state, params=None):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()

     # Initialize our training
     train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -531,7 +532,7 @@ def compute_loss(params):
                 0,
                 noise_scheduler.config.num_train_timesteps,
             )
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
            encoder_hidden_states = state.apply_fn(
                batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
            )[0]

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 2 additions & 1 deletion

@@ -261,7 +261,8 @@ def loop_body(step, args):
            )

        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * params["scheduler"].init_noise_sigma
+
        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
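On the pipeline side, the scheduler state rides along in the same `params` dict as the model weights instead of being stored on `self.scheduler`. A hedged sketch of the calling side; the `pipeline` and `params` names are assumptions for illustration, and only the `params["scheduler"]` key comes from this diff:

# Assumed setup: a Flax Stable Diffusion pipeline and its params loaded elsewhere.
scheduler_state = pipeline.scheduler.create_state()
params["scheduler"] = scheduler_state

# Inside the pipeline, schedule attributes are then read from that entry, e.g.:
# latents = latents * params["scheduler"].init_noise_sigma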
Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import flax
+import jax.numpy as jnp
+
+from .scheduling_utils_flax import broadcast_to_shape_from_left
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+
+    Returns:
+        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return jnp.array(betas, dtype=dtype)
+
+
+@flax.struct.dataclass
+class SchedulerCommonState:
+    alphas: jnp.ndarray
+    betas: jnp.ndarray
+    alphas_cumprod: jnp.ndarray
+
+
+def create_common_state(scheduler):
+    config = scheduler.config
+
+    if config.trained_betas is not None:
+        betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
+    elif config.beta_schedule == "linear":
+        betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
+    elif config.beta_schedule == "scaled_linear":
+        # this schedule is very specific to the latent diffusion model.
+        betas = (
+            jnp.linspace(
+                config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
+            )
+            ** 2
+        )
+    elif config.beta_schedule == "squaredcos_cap_v2":
+        # Glide cosine schedule
+        betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
+    else:
+        raise NotImplementedError(
+            f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
+        )
+
+    alphas = 1.0 - betas
+
+    alphas_cumprod = jnp.cumprod(alphas, axis=0)
+
+    return SchedulerCommonState(
+        alphas=alphas,
+        betas=betas,
+        alphas_cumprod=alphas_cumprod,
+    )
+
+
+def add_noise_common(
+    state: SchedulerCommonState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
+):
+    alphas_cumprod = state.alphas_cumprod
+
+    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+    sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+    sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+    return noisy_samples
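This new module factors out what the Flax schedulers have in common: building the beta/alpha schedule from a config, and the closed-form forward-diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise computed from the precomputed cumulative products. A small self-contained usage sketch; the linear beta range is an illustrative stand-in for a real scheduler config:

import jax
import jax.numpy as jnp

# Build a SchedulerCommonState by hand, mirroring the "linear" branch above.
betas = jnp.linspace(0.0001, 0.02, 1000, dtype=jnp.float32)
alphas = 1.0 - betas
state = SchedulerCommonState(alphas=alphas, betas=betas, alphas_cumprod=jnp.cumprod(alphas, axis=0))

original = jnp.ones((2, 3, 8, 8))  # illustrative sample batch
noise = jax.random.normal(jax.random.PRNGKey(0), original.shape)
timesteps = jnp.array([10, 500])  # one timestep per batch element

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
noisy = add_noise_common(state, original, noise, timesteps)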
