diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py
index e840565dd5c1..3c2a0b482922 100644
--- a/src/diffusers/onnx_utils.py
+++ b/src/diffusers/onnx_utils.py
@@ -38,8 +38,6 @@ class OnnxRuntimeModel:
-    base_model_prefix = "onnx_model"
-
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index b51b8b7e598c..847513bf15ea 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -30,7 +30,10 @@
 from tqdm.auto import tqdm
 
 from .configuration_utils import ConfigMixin
-from .utils import DIFFUSERS_CACHE, BaseOutput, logging
+from .modeling_utils import WEIGHTS_NAME
+from .onnx_utils import ONNX_WEIGHTS_NAME
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -285,6 +288,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
+            config_dict = cls.get_config_dict(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            # make sure we only download sub-folders and `diffusers` filenames
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+
+            # download all allow_patterns
             cached_folder = snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
@@ -293,6 +311,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 revision=revision,
+                allow_patterns=allow_patterns,
             )
         else:
             cached_folder = pretrained_model_name_or_path
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index e90bccb568ac..102a55a93e4b 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -14,6 +14,7 @@
 #    limitations under the License.
 
 import gc
+import os
 import random
 import tempfile
 import unittest
@@ -45,8 +46,11 @@
     UNet2DModel,
     VQModel,
 )
+from diffusers.modeling_utils import WEIGHTS_NAME
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils import CONFIG_NAME
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -707,6 +711,26 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_smart_download(self):
+        model_id = "hf-internal-testing/unet-pipeline-dummy"
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
+            local_repo_name = "--".join(["models"] + model_id.split("/"))
+            snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
+            snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
+
+            # inspect all downloaded files to make sure that everything is included
+            assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
+            assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
+            # let's make sure the super large numpy file:
+            # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
+            # is not downloaded, but all the expected ones are
+            assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
+
     @property
     def dummy_safety_checker(self):
         def check(images, *args, **kwargs):
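
For context, here is a minimal standalone sketch of what the new `allow_patterns` logic in `from_pretrained` does. This is illustrative only, not part of the patch: the literal filenames below stand in for the `WEIGHTS_NAME`, `SCHEDULER_CONFIG_NAME`, `CONFIG_NAME`, `ONNX_WEIGHTS_NAME`, and `cls.config_name` constants imported above, and it assumes an installed `huggingface_hub` version whose `snapshot_download` supports the `allow_patterns` argument.

```python
import json
import os

from huggingface_hub import hf_hub_download, snapshot_download

model_id = "hf-internal-testing/unet-pipeline-dummy"  # repo used in the test above

# `model_index.json` lists the pipeline components (scheduler, unet, ...);
# keys starting with "_" are metadata such as `_class_name`
index_path = hf_hub_download(model_id, "model_index.json")
with open(index_path) as f:
    config_dict = json.load(f)

# build the allow-list: every component sub-folder plus the known filenames
folder_names = [k for k in config_dict if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [
    "diffusion_pytorch_model.bin",  # WEIGHTS_NAME
    "scheduler_config.json",        # SCHEDULER_CONFIG_NAME
    "config.json",                  # CONFIG_NAME
    "model.onnx",                   # ONNX_WEIGHTS_NAME
    "model_index.json",             # DiffusionPipeline.config_name
]

# only files matching the patterns are fetched; e.g. `big_array.npy` is skipped
cached_folder = snapshot_download(model_id, allow_patterns=allow_patterns)
print(sorted(os.listdir(cached_folder)))
```

Using an allow-list rather than an ignore-list means any file not referenced by `model_index.json` or the known `diffusers` filenames (such as `big_array.npy` in the test repo) is skipped by default, which is what the new `test_smart_download` verifies.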