Skip to content

Commit f653140

Browse files
committed
Define SchedulerOutput to use torch or flax arrays.
1 parent 2b24dba commit f653140

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

src/diffusers/schedulers/scheduling_utils.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,43 @@
1414
import warnings
1515
from dataclasses import dataclass
1616

17-
import torch
18-
1917
from ..utils import BaseOutput
18+
from ..utils import is_torch_available, is_flax_available
2019

2120

2221
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
2322

2423

25-
@dataclass
26-
class SchedulerOutput(BaseOutput):
27-
"""
28-
Base class for the scheduler's step function output.
24+
if is_torch_available():
25+
import torch
2926

30-
Args:
31-
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
32-
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
33-
denoising loop.
34-
"""
27+
@dataclass
28+
class SchedulerOutput(BaseOutput):
29+
"""
30+
Base class for the scheduler's step function output.
31+
32+
Args:
33+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
34+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
35+
denoising loop.
36+
"""
37+
38+
prev_sample: torch.FloatTensor
39+
40+
if is_flax_available():
41+
import jax.numpy as jnp
42+
43+
class SchedulerOutput(BaseOutput):
44+
"""
45+
Base class for the scheduler's step function output.
46+
47+
Args:
48+
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
49+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
50+
denoising loop.
51+
"""
3552

36-
prev_sample: torch.FloatTensor
53+
prev_sample: jnp.ndarray
3754

3855

3956
class SchedulerMixin:

0 commit comments

Comments
 (0)