diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5b1364195c6c..d3e6113dac8e 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -26,6 +26,7 @@ import diffusers import PIL from huggingface_hub import snapshot_download +from packaging import version from PIL import Image from tqdm.auto import tqdm @@ -45,6 +46,7 @@ if is_transformers_available(): + import transformers from transformers import PreTrainedModel @@ -505,11 +507,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loading_kwargs["provider"] = provider loading_kwargs["sess_options"] = sess_options - if ( - issubclass(class_obj, diffusers.ModelMixin) - or is_transformers_available() + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + is_transformers_model = ( + is_transformers_available() and issubclass(class_obj, PreTrainedModel) - ): + and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") + ) + + if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map # check if the module is in a subdirectory