Skip to content

Commit 2b7d4a5

Browse files
[DeviceMap] Make sure stable diffusion can be loaded from older trans… (#860)
[DeviceMap] Make sure stable diffusion can be loaded from older transformers versiosn
1 parent 93a81a3 commit 2b7d4a5

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import diffusers
2727
import PIL
2828
from huggingface_hub import snapshot_download
29+
from packaging import version
2930
from PIL import Image
3031
from tqdm.auto import tqdm
3132

@@ -45,6 +46,7 @@
4546

4647

4748
if is_transformers_available():
49+
import transformers
4850
from transformers import PreTrainedModel
4951

5052

@@ -505,11 +507,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
505507
loading_kwargs["provider"] = provider
506508
loading_kwargs["sess_options"] = sess_options
507509

508-
if (
509-
issubclass(class_obj, diffusers.ModelMixin)
510-
or is_transformers_available()
510+
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
511+
is_transformers_model = (
512+
is_transformers_available()
511513
and issubclass(class_obj, PreTrainedModel)
512-
):
514+
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
515+
)
516+
517+
if is_diffusers_model or is_transformers_model:
513518
loading_kwargs["device_map"] = device_map
514519

515520
# check if the module is in a subdirectory

0 commit comments

Comments
 (0)