-
Notifications
You must be signed in to change notification settings - Fork 6.1k
misc fixes #2282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
misc fixes #2282
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,17 +13,12 @@ | |
# limitations under the License. | ||
|
||
|
||
import warnings | ||
import numpy as np # noqa: E402 | ||
|
||
from ...configuration_utils import ConfigMixin, register_to_config | ||
from ...schedulers.scheduling_utils import SchedulerMixin | ||
|
||
|
||
warnings.filterwarnings("ignore") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This warning filter will apply globally because this module ends up being imported at |
||
|
||
import numpy as np # noqa: E402 | ||
|
||
|
||
try: | ||
import librosa # noqa: E402 | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,10 +39,13 @@ | |
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline | ||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
|
||
from ...utils import is_omegaconf_available, is_safetensors_available | ||
from ...utils import is_omegaconf_available, is_safetensors_available, logging | ||
from ...utils.import_utils import BACKENDS_MAPPING | ||
|
||
|
||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
|
||
|
||
def shave_segments(path, n_shave_prefix_segments=1): | ||
""" | ||
Removes segments. Positive values shave the first segments, negative shave the last segments. | ||
|
@@ -801,11 +804,11 @@ def load_pipeline_from_original_stable_diffusion_ckpt( | |
corresponding to the original architecture. If `None`, will be | ||
automatically inferred by looking for a key that only exists in SD2.0 models. | ||
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable | ||
Siffusion v2 | ||
Diffusion v2 | ||
Base. Use 768 for Stable Diffusion v2. | ||
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion | ||
v1.X and Stable | ||
Siffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2. | ||
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. | ||
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically | ||
inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", | ||
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of | ||
|
@@ -820,6 +823,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt( | |
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A | ||
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. | ||
""" | ||
if prediction_type == "v-prediction": | ||
prediction_type = "v_prediction" | ||
Comment on lines
+826
to
+827
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just in case, because of the previous script docs. |
||
|
||
if not is_omegaconf_available(): | ||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) | ||
|
@@ -957,6 +962,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( | |
# Convert the text model. | ||
if model_type is None: | ||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] | ||
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") | ||
|
||
if model_type == "FrozenOpenCLIPEmbedder": | ||
text_model = convert_open_clip_checkpoint(checkpoint) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,7 +305,6 @@ def step( | |
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction | ||
|
||
if eta > 0: | ||
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Leftover comment from before we switched to rand_tensor. |
||
device = model_output.device | ||
if variance_noise is not None and generator is not None: | ||
raise ValueError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,23 +17,36 @@ | |
from packaging import version | ||
|
||
from .import_utils import is_flax_available, is_onnx_available, is_torch_available | ||
from .logging import get_logger | ||
|
||
|
||
global_rng = random.Random() | ||
|
||
logger = get_logger(__name__) | ||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
torch_device = "cuda" if torch.cuda.is_available() else "cpu" | ||
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( | ||
"1.12" | ||
) | ||
if "DIFFUSERS_TEST_DEVICE" in os.environ: | ||
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] | ||
|
||
if is_torch_higher_equal_than_1_12: | ||
# Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details | ||
mps_backend_registered = hasattr(torch.backends, "mps") | ||
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device | ||
available_backends = ["cuda", "cpu", "mps"] | ||
if torch_device not in available_backends: | ||
raise ValueError( | ||
f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:" | ||
f" {available_backends}" | ||
) | ||
logger.info(f"torch_device overrode to {torch_device}") | ||
Comment on lines
+30
to
+39
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It can be helpful to force the tests to run on a particular device — e.g. I use this snippet when I want to force the tests to run on the CPU when I'm using a MacBook or a machine with CUDA. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks good to me! |
||
else: | ||
torch_device = "cuda" if torch.cuda.is_available() else "cpu" | ||
is_torch_higher_equal_than_1_12 = version.parse( | ||
version.parse(torch.__version__).base_version | ||
) >= version.parse("1.12") | ||
|
||
if is_torch_higher_equal_than_1_12: | ||
# Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details | ||
mps_backend_registered = hasattr(torch.backends, "mps") | ||
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device | ||
|
||
|
||
def torch_all_close(a, b, *args, **kwargs): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
diffusers/src/diffusers/schedulers/scheduling_ddim.py
Line 278 in a7ca03a
The scheduler takes 'v_prediction'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!