diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index ed62b5fe579a..d604b1d0efd0 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -28,12 +28,17 @@ from requests import HTTPError from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax -from .modeling_utils import WEIGHTS_NAME, load_state_dict -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from .modeling_utils import load_state_dict +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + logging, +) -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" - logger = logging.get_logger(__name__) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index ef1ead9ecf0d..6935fc12deb3 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -24,10 +24,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging - - -WEIGHTS_NAME = "diffusion_pytorch_model.bin" +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging logger = logging.get_logger(__name__) diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py index 3c2a0b482922..863cf07488e4 100644 --- a/src/diffusers/onnx_utils.py +++ b/src/diffusers/onnx_utils.py @@ -24,16 +24,13 @@ from huggingface_hub import hf_hub_download -from .utils import is_onnx_available, logging +from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging if is_onnx_available(): import onnxruntime as ort -ONNX_WEIGHTS_NAME = "model.onnx" - - logger = logging.get_logger(__name__) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 15334e240c72..be3429e26ac5 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -30,10 +30,8 @@ from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .modeling_utils import WEIGHTS_NAME -from .onnx_utils import ONNX_WEIGHTS_NAME from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging INDEX_FILE = "diffusion_pytorch_model.bin" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c00a28e1058f..b63dbd2b285c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -47,6 +47,9 @@ CONFIG_NAME = "config.json" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" +ONNX_WEIGHTS_NAME = "model.onnx" HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 7689126b960b..145c26e34215 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -46,11 +46,10 @@ UNet2DModel, VQModel, ) -from diffusers.modeling_utils import WEIGHTS_NAME from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils import CONFIG_NAME +from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer