|
| 1 | +# Copyright 2022 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import math |
| 15 | + |
| 16 | +import flax |
| 17 | +import jax.numpy as jnp |
| 18 | + |
| 19 | +from .scheduling_utils_flax import broadcast_to_shape_from_left |
| 20 | + |
| 21 | + |
| 22 | +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray: |
| 23 | + """ |
| 24 | + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of |
| 25 | + (1-beta) over time from t = [0,1]. |
| 26 | +
|
| 27 | + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up |
| 28 | + to that part of the diffusion process. |
| 29 | +
|
| 30 | +
|
| 31 | + Args: |
| 32 | + num_diffusion_timesteps (`int`): the number of betas to produce. |
| 33 | + max_beta (`float`): the maximum beta to use; use values lower than 1 to |
| 34 | + prevent singularities. |
| 35 | +
|
| 36 | + Returns: |
| 37 | + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs |
| 38 | + """ |
| 39 | + |
| 40 | + def alpha_bar(time_step): |
| 41 | + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 |
| 42 | + |
| 43 | + betas = [] |
| 44 | + for i in range(num_diffusion_timesteps): |
| 45 | + t1 = i / num_diffusion_timesteps |
| 46 | + t2 = (i + 1) / num_diffusion_timesteps |
| 47 | + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) |
| 48 | + return jnp.array(betas, dtype=dtype) |
| 49 | + |
| 50 | + |
| 51 | +@flax.struct.dataclass |
| 52 | +class SchedulerCommonState: |
| 53 | + alphas: jnp.ndarray |
| 54 | + betas: jnp.ndarray |
| 55 | + alphas_cumprod: jnp.ndarray |
| 56 | + |
| 57 | + |
| 58 | +def create_common_state(scheduler): |
| 59 | + config = scheduler.config |
| 60 | + |
| 61 | + if config.trained_betas is not None: |
| 62 | + betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype) |
| 63 | + elif config.beta_schedule == "linear": |
| 64 | + betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype) |
| 65 | + elif config.beta_schedule == "scaled_linear": |
| 66 | + # this schedule is very specific to the latent diffusion model. |
| 67 | + betas = ( |
| 68 | + jnp.linspace( |
| 69 | + config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype |
| 70 | + ) |
| 71 | + ** 2 |
| 72 | + ) |
| 73 | + elif config.beta_schedule == "squaredcos_cap_v2": |
| 74 | + # Glide cosine schedule |
| 75 | + betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype) |
| 76 | + else: |
| 77 | + raise NotImplementedError( |
| 78 | + f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}" |
| 79 | + ) |
| 80 | + |
| 81 | + alphas = 1.0 - betas |
| 82 | + |
| 83 | + alphas_cumprod = jnp.cumprod(alphas, axis=0) |
| 84 | + |
| 85 | + return SchedulerCommonState( |
| 86 | + alphas=alphas, |
| 87 | + betas=betas, |
| 88 | + alphas_cumprod=alphas_cumprod, |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +def add_noise_common( |
| 93 | + state: SchedulerCommonState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray |
| 94 | +): |
| 95 | + alphas_cumprod = state.alphas_cumprod |
| 96 | + |
| 97 | + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 |
| 98 | + sqrt_alpha_prod = sqrt_alpha_prod.flatten() |
| 99 | + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) |
| 100 | + |
| 101 | + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 |
| 102 | + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() |
| 103 | + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) |
| 104 | + |
| 105 | + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise |
| 106 | + return noisy_samples |
0 commit comments