Skip to content

Commit 752fb76

Browse files
committed
Remove scheduling_common_flax and some renames
1 parent 4ee3fce commit 752fb76

File tree

7 files changed

+118
-131
lines changed

7 files changed

+118
-131
lines changed

src/diffusers/schedulers/scheduling_common_flax.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,18 @@
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
2525
from ..utils import deprecate
26-
from .scheduling_common_flax import SchedulerCommonState, add_noise_common, create_common_state
2726
from .scheduling_utils_flax import (
2827
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
28+
CommonSchedulerState,
2929
FlaxSchedulerMixin,
3030
FlaxSchedulerOutput,
31+
add_noise_common,
3132
)
3233

3334

3435
@flax.struct.dataclass
3536
class DDIMSchedulerState:
36-
common: SchedulerCommonState
37+
common: CommonSchedulerState
3738
final_alpha_cumprod: jnp.ndarray
3839

3940
# setable values
@@ -44,7 +45,7 @@ class DDIMSchedulerState:
4445
@classmethod
4546
def create(
4647
cls,
47-
common: SchedulerCommonState,
48+
common: CommonSchedulerState,
4849
final_alpha_cumprod: jnp.ndarray,
4950
init_noise_sigma: jnp.ndarray,
5051
timesteps: jnp.ndarray,
@@ -133,9 +134,9 @@ def __init__(
133134

134135
self.dtype = dtype
135136

136-
def create_state(self, common: Optional[SchedulerCommonState] = None) -> DDIMSchedulerState:
137+
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
137138
if common is None:
138-
common = create_common_state(self)
139+
common = CommonSchedulerState.create(self)
139140

140141
# At every step in ddim, we are looking into the previous alphas_cumprod
141142
# For the final step, there is no previous alphas_cumprod because we are already at 0

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,26 @@
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
2525
from ..utils import deprecate
26-
from .scheduling_common_flax import SchedulerCommonState, add_noise_common, create_common_state
2726
from .scheduling_utils_flax import (
2827
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
28+
CommonSchedulerState,
2929
FlaxSchedulerMixin,
3030
FlaxSchedulerOutput,
31+
add_noise_common,
3132
)
3233

3334

3435
@flax.struct.dataclass
3536
class DDPMSchedulerState:
36-
common: SchedulerCommonState
37+
common: CommonSchedulerState
3738

3839
# setable values
3940
init_noise_sigma: jnp.ndarray
4041
timesteps: jnp.ndarray
4142
num_inference_steps: Optional[int] = None
4243

4344
@classmethod
44-
def create(cls, common: SchedulerCommonState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
45+
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
4546
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
4647

4748

@@ -116,9 +117,9 @@ def __init__(
116117

117118
self.dtype = dtype
118119

119-
def create_state(self, common: Optional[SchedulerCommonState] = None) -> DDPMSchedulerState:
120+
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
120121
if common is None:
121-
common = create_common_state(self)
122+
common = CommonSchedulerState.create(self)
122123

123124
# standard deviation of the initial noise distribution
124125
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,18 @@
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
2525
from ..utils import deprecate
26-
from .scheduling_common_flax import SchedulerCommonState, add_noise_common, create_common_state
2726
from .scheduling_utils_flax import (
2827
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
28+
CommonSchedulerState,
2929
FlaxSchedulerMixin,
3030
FlaxSchedulerOutput,
31+
add_noise_common,
3132
)
3233

3334

3435
@flax.struct.dataclass
3536
class DPMSolverMultistepSchedulerState:
36-
common: SchedulerCommonState
37+
common: CommonSchedulerState
3738
alpha_t: jnp.ndarray
3839
sigma_t: jnp.ndarray
3940
lambda_t: jnp.ndarray
@@ -52,7 +53,7 @@ class DPMSolverMultistepSchedulerState:
5253
@classmethod
5354
def create(
5455
cls,
55-
common: SchedulerCommonState,
56+
common: CommonSchedulerState,
5657
alpha_t: jnp.ndarray,
5758
sigma_t: jnp.ndarray,
5859
lambda_t: jnp.ndarray,
@@ -177,9 +178,9 @@ def __init__(
177178

178179
self.dtype = dtype
179180

180-
def create_state(self, common: Optional[SchedulerCommonState] = None) -> DPMSolverMultistepSchedulerState:
181+
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
181182
if common is None:
182-
common = create_common_state(self)
183+
common = CommonSchedulerState.create(self)
183184

184185
# Currently we only support VP-type noise schedule
185186
alpha_t = jnp.sqrt(common.alphas_cumprod)

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from scipy import integrate
2121

2222
from ..configuration_utils import ConfigMixin, register_to_config
23-
from .scheduling_common_flax import SchedulerCommonState, create_common_state
2423
from .scheduling_utils_flax import (
2524
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
25+
CommonSchedulerState,
2626
FlaxSchedulerMixin,
2727
FlaxSchedulerOutput,
2828
broadcast_to_shape_from_left,
@@ -31,7 +31,7 @@
3131

3232
@flax.struct.dataclass
3333
class LMSDiscreteSchedulerState:
34-
common: SchedulerCommonState
34+
common: CommonSchedulerState
3535

3636
# setable values
3737
init_noise_sigma: jnp.ndarray
@@ -44,7 +44,7 @@ class LMSDiscreteSchedulerState:
4444

4545
@classmethod
4646
def create(
47-
cls, common: SchedulerCommonState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
47+
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
4848
):
4949
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
5050

@@ -103,9 +103,9 @@ def __init__(
103103
):
104104
self.dtype = dtype
105105

106-
def create_state(self, common: Optional[SchedulerCommonState] = None) -> LMSDiscreteSchedulerState:
106+
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
107107
if common is None:
108-
common = create_common_state(self)
108+
common = CommonSchedulerState.create(self)
109109

110110
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
111111
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,18 @@
2222
import jax.numpy as jnp
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
25-
from .scheduling_common_flax import SchedulerCommonState, add_noise_common, create_common_state
2625
from .scheduling_utils_flax import (
2726
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
27+
CommonSchedulerState,
2828
FlaxSchedulerMixin,
2929
FlaxSchedulerOutput,
30+
add_noise_common,
3031
)
3132

3233

3334
@flax.struct.dataclass
3435
class PNDMSchedulerState:
35-
common: SchedulerCommonState
36+
common: CommonSchedulerState
3637
final_alpha_cumprod: jnp.ndarray
3738

3839
# setable values
@@ -51,7 +52,7 @@ class PNDMSchedulerState:
5152
@classmethod
5253
def create(
5354
cls,
54-
common: SchedulerCommonState,
55+
common: CommonSchedulerState,
5556
final_alpha_cumprod: jnp.ndarray,
5657
init_noise_sigma: jnp.ndarray,
5758
timesteps: jnp.ndarray,
@@ -139,9 +140,9 @@ def __init__(
139140
# mainly at formula (9), (12), (13) and the Algorithm 2.
140141
self.pndm_order = 4
141142

142-
def create_state(self, common: Optional[SchedulerCommonState] = None) -> PNDMSchedulerState:
143+
def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
143144
if common is None:
144-
common = create_common_state(self)
145+
common = CommonSchedulerState.create(self)
145146

146147
# At every step in ddim, we are looking into the previous alphas_cumprod
147148
# For the final step, there is no previous alphas_cumprod because we are already at 0

0 commit comments

Comments
 (0)