diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 44c494e3bf9d..57d8358cf2de 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -472,9 +472,7 @@ def collate_fn(examples): apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer ) - noise_scheduler = FlaxDDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 - ) + noise_scheduler = FlaxDDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Initialize our training train_rngs = jax.random.split(rng, jax.local_device_count()) @@ -528,24 +526,31 @@ def compute_loss(params): {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True ) noise_pred = unet_outputs.sample + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.with_prior_preservation: # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. noise_pred, noise_pred_prior = jnp.split(noise_pred, 2, axis=0) - noise, noise_prior = jnp.split(noise, 2, axis=0) + target, target_prior = jnp.split(target, 2, axis=0) # Compute instance loss - loss = (noise - noise_pred) ** 2 + loss = (target - noise_pred) ** 2 loss = loss.mean() # Compute prior loss - prior_loss = (noise_prior - noise_pred_prior) ** 2 + prior_loss = (target_prior - noise_pred_prior) ** 2 prior_loss = prior_loss.mean() # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = (noise - noise_pred) ** 2 + loss = (target - noise_pred) ** 2 loss = loss.mean() return loss diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 056ff6bad0f4..06da292088de 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -414,9 +414,7 @@ def collate_fn(examples): state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer) - noise_scheduler = FlaxDDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 - ) + noise_scheduler = FlaxDDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Initialize our training rng = jax.random.PRNGKey(args.seed) @@ -461,7 +459,14 @@ def compute_loss(params): # Predict the noise residual and compute loss unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True) noise_pred = unet_outputs.sample - loss = (noise - noise_pred) ** 2 + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = (target - noise_pred) ** 2 loss = loss.mean() return loss diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index a9fa9e36931b..6c834da830c8 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -502,9 +502,7 @@ def update_fn(updates, state, params=None): state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx) - noise_scheduler = FlaxDDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 - ) + noise_scheduler = FlaxDDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Initialize our training train_rngs = jax.random.split(rng, jax.local_device_count()) @@ -539,7 +537,14 @@ def compute_loss(params): {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False ) noise_pred = unet_outputs.sample - loss = (noise - noise_pred) ** 2 + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = (target - noise_pred) ** 2 loss = loss.mean() return loss diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index a6f7eace2637..b0347932d1d6 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -322,5 +322,22 @@ def add_noise( noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + def get_velocity( + self, + sample: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, sample.shape) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, sample.shape) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 84186546ed9e..f61d8b275e01 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -314,5 +314,22 @@ def add_noise( noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + def get_velocity( + self, + sample: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, sample.shape) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, sample.shape) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + def __len__(self): return self.config.num_train_timesteps