Skip to content

Commit fd5c3c0

Browse files
misc fixes (#2282)
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 648090e commit fd5c3c0

File tree

8 files changed

+44
-23
lines changed

8 files changed

+44
-23
lines changed

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@
4848
"--pipeline_type",
4949
default=None,
5050
type=str,
51-
help="The pipeline type. If `None` pipeline will be automatically inferred.",
51+
help=(
52+
"The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
53+
". If `None` pipeline will be automatically inferred."
54+
),
5255
)
5356
parser.add_argument(
5457
"--image_size",
@@ -65,7 +68,7 @@
6568
type=str,
6669
help=(
6770
"The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
68-
" Siffusion v2 Base. Use 'v-prediction' for Stable Diffusion v2."
71+
" Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
6972
),
7073
)
7174
parser.add_argument(
@@ -79,8 +82,7 @@
7982
)
8083
parser.add_argument(
8184
"--upcast_attention",
82-
default=False,
83-
type=bool,
85+
action="store_true",
8486
help=(
8587
"Whether the attention computation should always be upcasted. This is necessary when running stable"
8688
" diffusion 2.1."
@@ -111,5 +113,6 @@
111113
num_in_channels=args.num_in_channels,
112114
upcast_attention=args.upcast_attention,
113115
from_safetensors=args.from_safetensors,
116+
device=args.device,
114117
)
115118
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

src/diffusers/pipelines/audio_diffusion/mel.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,12 @@
1313
# limitations under the License.
1414

1515

16-
import warnings
16+
import numpy as np # noqa: E402
1717

1818
from ...configuration_utils import ConfigMixin, register_to_config
1919
from ...schedulers.scheduling_utils import SchedulerMixin
2020

2121

22-
warnings.filterwarnings("ignore")
23-
24-
import numpy as np # noqa: E402
25-
26-
2722
try:
2823
import librosa # noqa: E402
2924

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@
3939
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
4040
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
4141

42-
from ...utils import is_omegaconf_available, is_safetensors_available
42+
from ...utils import is_omegaconf_available, is_safetensors_available, logging
4343
from ...utils.import_utils import BACKENDS_MAPPING
4444

4545

46+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47+
48+
4649
def shave_segments(path, n_shave_prefix_segments=1):
4750
"""
4851
Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -801,11 +804,11 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
801804
corresponding to the original architecture. If `None`, will be
802805
automatically inferred by looking for a key that only exists in SD2.0 models.
803806
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable
804-
Siffusion v2
807+
Diffusion v2
805808
Base. Use 768 for Stable Diffusion v2.
806809
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
807810
v1.X and Stable
808-
Siffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
811+
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
809812
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
810813
inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
811814
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
@@ -820,6 +823,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
820823
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
821824
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
822825
"""
826+
if prediction_type == "v-prediction":
827+
prediction_type = "v_prediction"
823828

824829
if not is_omegaconf_available():
825830
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -957,6 +962,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
957962
# Convert the text model.
958963
if model_type is None:
959964
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
965+
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
960966

961967
if model_type == "FrozenOpenCLIPEmbedder":
962968
text_model = convert_open_clip_checkpoint(checkpoint)

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def step(
305305
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
306306

307307
if eta > 0:
308-
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
309308
device = model_output.device
310309
if variance_noise is not None and generator is not None:
311310
raise ValueError(

src/diffusers/schedulers/scheduling_unclip.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@ def __init__(
106106
clip_sample: bool = True,
107107
clip_sample_range: Optional[float] = 1.0,
108108
prediction_type: str = "epsilon",
109+
beta_schedule: str = "squaredcos_cap_v2",
109110
):
110-
# beta scheduler is "squaredcos_cap_v2"
111+
if beta_schedule != "squaredcos_cap_v2":
112+
raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'")
113+
111114
self.betas = betas_for_alpha_bar(num_train_timesteps)
112115

113116
self.alphas = 1.0 - self.betas

src/diffusers/utils/testing_utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,36 @@
1717
from packaging import version
1818

1919
from .import_utils import is_flax_available, is_onnx_available, is_torch_available
20+
from .logging import get_logger
2021

2122

2223
global_rng = random.Random()
2324

25+
logger = get_logger(__name__)
2426

2527
if is_torch_available():
2628
import torch
2729

28-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
29-
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
30-
"1.12"
31-
)
30+
if "DIFFUSERS_TEST_DEVICE" in os.environ:
31+
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
3232

33-
if is_torch_higher_equal_than_1_12:
34-
# Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
35-
mps_backend_registered = hasattr(torch.backends, "mps")
36-
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
33+
available_backends = ["cuda", "cpu", "mps"]
34+
if torch_device not in available_backends:
35+
raise ValueError(
36+
f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:"
37+
f" {available_backends}"
38+
)
39+
logger.info(f"torch_device overrode to {torch_device}")
40+
else:
41+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
42+
is_torch_higher_equal_than_1_12 = version.parse(
43+
version.parse(torch.__version__).base_version
44+
) >= version.parse("1.12")
45+
46+
if is_torch_higher_equal_than_1_12:
47+
# Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
48+
mps_backend_registered = hasattr(torch.backends, "mps")
49+
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
3750

3851

3952
def torch_all_close(a, b, *args, **kwargs):

tests/pipelines/unclip/test_unclip.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3232
pipeline_class = UnCLIPPipeline
33+
test_xformers_attention = False
3334

3435
required_optional_params = [
3536
"generator",

tests/test_pipelines_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def _test_inference_batch_single_identical(
259259
# Taking the median of the largest <n> differences
260260
# is resilient to outliers
261261
diff = np.abs(output_batch[0][0] - output[0][0])
262+
diff = diff.flatten()
262263
diff.sort()
263264
max_diff = np.median(diff[-5:])
264265
else:

0 commit comments

Comments
 (0)