diff --git a/setup.py b/setup.py index 598291fa4546..fd82a2adb2c6 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.20.2", + "huggingface-hub>=0.23.2", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index ca233a4158bc..9413be5e4eed 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -9,7 +9,7 @@ "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.20.2", + "huggingface-hub": "huggingface-hub>=0.23.2", "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark>=0.2.0", diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index afcee2471662..5604879f40ab 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -18,13 +18,19 @@ import inspect import os from collections import OrderedDict +from pathlib import Path from typing import List, Optional, Union import safetensors import torch +from huggingface_hub.utils import EntryNotFoundError from ..utils import ( + SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, + WEIGHTS_INDEX_NAME, + _add_variant, + _get_model_file, is_accelerate_available, is_torch_version, logging, @@ -175,3 +181,52 @@ def load(module: torch.nn.Module, prefix: str = ""): load(model_to_load) return error_msgs + + +def _fetch_index_file( + is_local, + pretrained_model_name_or_path, + subfolder, + use_safetensors, + cache_dir, + variant, + force_download, + resume_download, + proxies, + local_files_only, + token, + revision, + user_agent, + commit_hash, +): + if is_local: + index_file = Path( + pretrained_model_name_or_path, + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ) + else: + index_file_in_repo = Path( + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ).as_posix() + try: + index_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=index_file_in_repo, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + index_file = Path(index_file) + except (EntryNotFoundError, EnvironmentError): + index_file = None + + return index_file diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index da4f798202d7..5f011c46e43d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -16,6 +16,7 @@ import inspect import itertools +import json import os import re from collections import OrderedDict @@ -25,7 +26,7 @@ import safetensors import torch -from huggingface_hub import create_repo +from huggingface_hub import create_repo, split_torch_state_dict_into_shards from huggingface_hub.utils import validate_hf_hub_args from torch import Tensor, nn @@ -33,9 +34,12 @@ from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, _add_variant, + _get_checkpoint_shard_files, _get_model_file, deprecate, is_accelerate_available, @@ -49,6 +53,7 @@ ) from .model_loading_utils import ( _determine_device_map, + _fetch_index_file, _load_state_dict_into_model, load_model_dict_into_meta, load_state_dict, @@ -57,6 +62,8 @@ logger = logging.get_logger(__name__) +_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") + if is_torch_version(">=", "1.9.0"): _LOW_CPU_MEM_USAGE_DEFAULT = True @@ -263,6 +270,7 @@ def save_pretrained( save_function: Optional[Callable] = None, safe_serialization: bool = True, variant: Optional[str] = None, + max_shard_size: Union[int, str] = "5GB", push_to_hub: bool = False, **kwargs, ): @@ -285,6 +293,10 @@ def save_pretrained( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `"5GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your @@ -296,6 +308,14 @@ def save_pretrained( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + weight_name_split = weights_name.split(".") + if len(weight_name_split) in [2, 3]: + weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:]) + else: + raise ValueError(f"Invalid {weights_name} provided.") + os.makedirs(save_directory, exist_ok=True) if push_to_hub: @@ -317,18 +337,58 @@ def save_pretrained( # Save the model state_dict = model_to_save.state_dict() - weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME - weights_name = _add_variant(weights_name, variant) - # Save the model - if safe_serialization: - safetensors.torch.save_file( - state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"} + state_dict_split = split_torch_state_dict_into_shards( + state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern + ) + + # Clean the folder from a previous save + if is_main_process: + for filename in os.listdir(save_directory): + if filename in state_dict_split.filename_to_tensors.keys(): + continue + full_filename = os.path.join(save_directory, filename) + if not os.path.isfile(full_filename): + continue + weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") + weights_without_ext = weights_without_ext.replace("{suffix}", "") + filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + if ( + filename.startswith(weights_without_ext) + and _REGEX_SHARD.fullmatch(filename_without_ext) is not None + ): + os.remove(full_filename) + + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + filepath = os.path.join(save_directory, filename) + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + else: + torch.save(shard, filepath) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." ) else: - torch.save(state_dict, Path(save_directory, weights_name).as_posix()) - - logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}") + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") if push_to_hub: # Create a new empty model card and eventually tag it @@ -566,6 +626,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P **kwargs, ) + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file = _fetch_index_file( + is_local=is_local, + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder or "", + use_safetensors=use_safetensors, + cache_dir=cache_dir, + variant=variant, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent=user_agent, + commit_hash=commit_hash, + ) + if index_file is not None and index_file.is_file(): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + # load model model_file = None if from_flax: @@ -590,7 +676,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = load_flax_checkpoint_in_pytorch_model(model, model_file) else: - if use_safetensors: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + + elif use_safetensors and not is_sharded: try: model_file = _get_model_file( pretrained_model_name_or_path, @@ -606,11 +706,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent=user_agent, commit_hash=commit_hash, ) + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") if not allow_pickle: - raise e - pass - if model_file is None: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), @@ -632,7 +737,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None: + if device_map is None and not is_sharded: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) @@ -670,7 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P try: accelerate.load_checkpoint_and_dispatch( model, - model_file, + model_file if not is_sharded else sharded_ckpt_cached_folder, device_map, max_memory=max_memory, offload_folder=offload_folder, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7ab0a94e5677..1612ca5ae4c0 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -28,9 +28,11 @@ MIN_PEFT_VERSION, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) from .deprecation_utils import deprecate @@ -40,6 +42,7 @@ from .hub_utils import ( PushToHubMixin, _add_variant, + _get_checkpoint_shard_files, _get_model_file, extract_commit_hash, http_user_agent, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index bc4268a32ac5..553ac5d1bb27 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -28,9 +28,11 @@ CONFIG_NAME = "config.json" WEIGHTS_NAME = "diffusion_pytorch_model.bin" +WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json" FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" ONNX_WEIGHTS_NAME = "model.onnx" SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json" SAFETENSORS_FILE_EXTENSION = "safetensors" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 83f02848fcf4..d0253ff474d9 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -14,6 +14,7 @@ # limitations under the License. +import json import os import re import sys @@ -29,6 +30,8 @@ ModelCardData, create_repo, hf_hub_download, + model_info, + snapshot_download, upload_folder, ) from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE @@ -393,6 +396,109 @@ def _get_model_file( ) +# Adapted from +# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976 +# Differences are in parallelization of shard downloads and checking if shards are present. + + +def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames): + shards_path = os.path.join(local_dir, subfolder) + shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] + for shard_file in shard_filenames: + if not os.path.exists(shard_file): + raise ValueError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + +def _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_filename, + cache_dir=None, + proxies=None, + resume_download=False, + local_files_only=False, + token=None, + user_agent=None, + revision=None, + subfolder="", +): + """ + For a given model: + + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + + For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + original_shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + shards_path = os.path.join(pretrained_model_name_or_path, subfolder) + + # First, let's deal with local folder. + if os.path.isdir(pretrained_model_name_or_path): + _check_if_shards_exist_locally( + pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + return pretrained_model_name_or_path, sharded_metadata + + # At this stage pretrained_model_name_or_path is a model identifier on the Hub + allow_patterns = original_shard_filenames + ignore_patterns = ["*.json", "*.md"] + if not local_files_only: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e + + # If `local_files_only=True`, `cached_folder` may not contain all the shard files. + if local_files_only: + _check_if_shards_exist_locally( + local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + + return cached_folder, sharded_metadata + + class PushToHubMixin: """ A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 6f70f5888910..475d6e415eea 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -131,7 +131,6 @@ except importlib_metadata.PackageNotFoundError: _unidecode_available = False - _onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a72290b75f00..a8564e0baf7b 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -14,6 +14,8 @@ # limitations under the License. import inspect +import json +import os import tempfile import traceback import unittest @@ -37,7 +39,7 @@ XFormersAttnProcessor, ) from diffusers.training_utils import EMAModel -from diffusers.utils import is_torch_npu_available, is_xformers_available, logging +from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging from diffusers.utils.testing_utils import ( CaptureLogger, get_python_version, @@ -129,7 +131,9 @@ def test_one_request_upon_cached(self): ) download_requests = [r.method for r in m.request_history] - assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model" + assert ( + download_requests.count("HEAD") == 3 + ), "3 HEAD requests one for config, one for model, and one for shard index file." assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" with requests_mock.mock(real_http=True) as m: @@ -142,8 +146,8 @@ def test_one_request_upon_cached(self): cache_requests = [r.method for r in m.request_history] assert ( - "HEAD" == cache_requests[0] and len(cache_requests) == 1 - ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + "HEAD" == cache_requests[0] and len(cache_requests) == 2 + ), "We should call only `model_info` to check for commit hash and knowing if shard index is present." def test_weight_overwrite(self): with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: @@ -866,6 +870,41 @@ def test_model_parallelism(self): self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + @require_torch_gpu + def test_sharded_checkpoints(self): + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + if model._no_split_modules is None: + return + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") + self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) + + # Now check if the right number of shards exists. First, let's get the number of shards. + # Since this number can be dependent on the model being tested, it's important that we calculate it + # instead of hardcoding it. + with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: + weight_map_dict = json.load(f)["weight_map"] + first_key = list(weight_map_dict.keys())[0] + weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors + expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) + + actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) + self.assertTrue(actual_num_shards == expected_num_shards) + + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto") + + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index f6dcebf29933..ca5d964f7609 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -21,6 +21,7 @@ from collections import OrderedDict import torch +from huggingface_hub import snapshot_download from parameterized import parameterized from pytest import mark @@ -1034,6 +1035,25 @@ def test_ip_adapter_plus(self): assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub_local(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_peft_backend def test_lora(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 69d06d980af1..789f9f7f6686 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -29,6 +29,7 @@ import requests_mock import safetensors.torch import torch +import torch.nn as nn from parameterized import parameterized from PIL import Image from requests.exceptions import HTTPError @@ -135,6 +136,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): class CustomEncoder(ModelMixin, ConfigMixin): def __init__(self): super().__init__() + self.linear = nn.Linear(3, 3) class CustomPipeline(DiffusionPipeline):