|
14 | 14 | import warnings |
15 | 15 | from dataclasses import dataclass |
16 | 16 |
|
17 | | -import torch |
18 | | - |
19 | 17 | from ..utils import BaseOutput |
| 18 | +from ..utils import is_torch_available, is_flax_available |
20 | 19 |
|
21 | 20 |
|
22 | 21 | SCHEDULER_CONFIG_NAME = "scheduler_config.json" |
23 | 22 |
|
24 | 23 |
|
25 | | -@dataclass |
26 | | -class SchedulerOutput(BaseOutput): |
27 | | - """ |
28 | | - Base class for the scheduler's step function output. |
| 24 | +if is_torch_available(): |
| 25 | + import torch |
29 | 26 |
|
30 | | - Args: |
31 | | - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): |
32 | | - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the |
33 | | - denoising loop. |
34 | | - """ |
| 27 | + @dataclass |
| 28 | + class SchedulerOutput(BaseOutput): |
| 29 | + """ |
| 30 | + Base class for the scheduler's step function output. |
| 31 | +
|
| 32 | + Args: |
| 33 | + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): |
| 34 | + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the |
| 35 | + denoising loop. |
| 36 | + """ |
| 37 | + |
| 38 | + prev_sample: torch.FloatTensor |
| 39 | + |
| 40 | +if is_flax_available(): |
| 41 | + import jax.numpy as jnp |
| 42 | + |
| 43 | + class SchedulerOutput(BaseOutput): |
| 44 | + """ |
| 45 | + Base class for the scheduler's step function output. |
| 46 | +
|
| 47 | + Args: |
| 48 | + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): |
| 49 | + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the |
| 50 | + denoising loop. |
| 51 | + """ |
35 | 52 |
|
36 | | - prev_sample: torch.FloatTensor |
| 53 | + prev_sample: jnp.ndarray |
37 | 54 |
|
38 | 55 |
|
39 | 56 | class SchedulerMixin: |
|
0 commit comments