2 changes: 2 additions & 0 deletions docs/source/en/api/loaders/single_file.md
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableDiffusionXLInstructPix2PixPipeline`]
 - [`StableDiffusionXLControlNetPipeline`]
 - [`StableDiffusionXLKDiffusionPipeline`]
+- [`StableDiffusion3Pipeline`]
 - [`LatentConsistencyModelPipeline`]
 - [`LatentConsistencyModelImg2ImgPipeline`]
 - [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableCascadeUNet`]
 - [`AutoencoderKL`]
 - [`ControlNetModel`]
+- [`SD3Transformer2DModel`]

 ## FromSingleFileMixin

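Both additions to the supported list are directly loadable. For the model-level entry, a minimal usage sketch; the checkpoint filename `sd3_medium.safetensors` is an assumption based on the release layout, mirroring the truncated context line in a later hunk:

```python
import torch
from diffusers import SD3Transformer2DModel

# Load only the MMDiT transformer from the single-file release
# (assumed filename; adjust to the checkpoint you actually have).
model = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors",
    torch_dtype=torch.float16,
)
```

The hunks that follow come from the Stable Diffusion 3 pipeline usage docs.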
@@ -21,9 +21,9 @@ The abstract from the paper is:

 ## Usage Example

 _As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._

 Use the command below to log in:

 ```bash
 huggingface-cli login
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability

 ## Loading the single checkpoint for the `StableDiffusion3Pipeline`

+### Loading the single file checkpoint without T5
+
 ```python
 import torch
 from diffusers import StableDiffusion3Pipeline
-from transformers import T5EncoderModel

-text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
-pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
+    torch_dtype=torch.float16,
+    text_encoder_3=None
+)
 pipe.enable_model_cpu_offload()

 image = pipe("a picture of a cat holding a sign that says hello world").images[0]
 image.save('sd3-single-file.png')
 ```

-<Tip>
-`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
-</Tip>
+### Loading the single file checkpoint with T5
+
+```python
+import torch
+from diffusers import StableDiffusion3Pipeline
+
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
+    torch_dtype=torch.float16,
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file-t5-fp8.png')
+```

 ## StableDiffusion3Pipeline

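The old example deleted above loaded the T5 encoder separately and passed it in; that pattern still works if you want the full-precision T5 together with the smaller single-file checkpoint. A sketch assuming access to the same gated repositories as the docs:

```python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel

# T5 comes from the multi-folder Diffusers repo; everything else from the single file.
text_encoder_3 = T5EncoderModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="text_encoder_3",
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3Pipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
    torch_dtype=torch.float16,
    text_encoder_3=text_encoder_3,
)
pipe.enable_model_cpu_offload()
```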
12 changes: 12 additions & 0 deletions src/diffusers/loaders/single_file.py
@@ -28,9 +28,11 @@
     _legacy_load_safety_checker,
     _legacy_load_scheduler,
     create_diffusers_clip_model_from_ldm,
+    create_diffusers_t5_model_from_checkpoint,
     fetch_diffusers_config,
     fetch_original_config,
     is_clip_model_in_single_file,
+    is_t5_in_single_file,
     load_single_file_checkpoint,
 )

@@ -118,6 +120,16 @@ def load_single_file_sub_model(
             is_legacy_loading=is_legacy_loading,
         )

+    elif is_transformers_model and is_t5_in_single_file(checkpoint):
+        loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+            class_obj,
+            checkpoint=checkpoint,
+            config=cached_model_config_path,
+            subfolder=name,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+        )
+
     elif is_tokenizer and is_legacy_loading:
         loaded_sub_model = _legacy_load_clip_tokenizer(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
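The new branch means a `transformers` sub-model whose weights are already embedded in the single file (as with the `t5xxlfp8` checkpoint from the docs above) is converted in place rather than fetched from the Hub. A quick probe of that detection, using the same internal helpers the loader imports; these are internal APIs shown only as a sketch and subject to change:

```python
from diffusers.loaders.single_file_utils import (
    is_t5_in_single_file,
    load_single_file_checkpoint,
)

# Checkpoint URL taken from the docs changes above.
checkpoint = load_single_file_checkpoint(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors"
)
print(is_t5_in_single_file(checkpoint))  # True when T5-XXL weights are embedded
```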
32 changes: 23 additions & 9 deletions src/diffusers/loaders/single_file_utils.py
@@ -252,7 +252,6 @@
 LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
-    "text_encoders.clip_l.transformer.",
 ]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):


 def is_open_clip_sd3_model(checkpoint):
-    is_open_clip_sdxl_refiner_model(checkpoint)
     if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
         return True

     return False


 def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
         return True

     return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint


-def convert_ldm_clip_checkpoint(checkpoint):
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
     keys = list(checkpoint.keys())
     text_model_dict = {}

-    remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+    remove_prefixes = []
+    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+    if remove_prefix:
+        remove_prefixes.append(remove_prefix)

     for key in keys:
         for prefix in remove_prefixes:
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
     ):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

+    elif (
+        is_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
     elif is_open_clip_model(checkpoint):
         prefix = "cond_stage_model.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

-    elif is_open_clip_sd3_model(checkpoint):
-        prefix = "text_encoders.clip_g.transformer."
-        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+    elif (
+        is_open_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")

     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
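The shape guards added above are needed because one SD3 checkpoint carries both text encoders, so key presence alone cannot tell CLIP-L from OpenCLIP-G; the position-embedding width identifies which encoder matches the sub-model being built. The `torch.eye` line backfills the `text_projection.weight` that the checkpoint does not store for `clip_l`. A minimal sketch of that reasoning; the 768/1280 hidden sizes are background assumptions, not taken from the diff:

```python
import torch

# One SD3 checkpoint holds both encoders; only their widths differ
# (assumed here: 768 for CLIP-L, 1280 for OpenCLIP-G).
clip_l_pos_embed = torch.zeros(77, 768)
clip_g_pos_embed = torch.zeros(77, 1280)

position_embedding_dim = 768  # width of the sub-model being constructed
assert clip_l_pos_embed.shape[-1] == position_embedding_dim
assert clip_g_pos_embed.shape[-1] != position_embedding_dim

# An identity text_projection leaves pooled embeddings unchanged.
pooled = torch.randn(1, position_embedding_dim)
assert torch.allclose(pooled, pooled @ torch.eye(position_embedding_dim))
```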
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
     keys = list(checkpoint.keys())
     text_model_dict = {}

-    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+    remove_prefixes = ["text_encoders.t5xxl.transformer."]

     for key in keys:
         for prefix in remove_prefixes:
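The widened prefix matters because a T5 encoder's state dict keeps keys both under `encoder.` and at the top level (the shared token embedding); stripping the longer `...transformer.encoder.` prefix would drop the shared embedding and misname the encoder blocks. A standalone sketch of the stripping pattern the loop above implements, with hypothetical checkpoint keys for illustration:

```python
# Hypothetical keys in the style of the SD3 single-file checkpoint.
checkpoint = {
    "text_encoders.t5xxl.transformer.shared.weight": "...",
    "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight": "...",
}

remove_prefixes = ["text_encoders.t5xxl.transformer."]

text_model_dict = {}
for key in list(checkpoint.keys()):
    for prefix in remove_prefixes:
        if key.startswith(prefix):
            # Keep everything after the prefix, so "encoder." and "shared."
            # survive in the remapped key names.
            text_model_dict[key[len(prefix):]] = checkpoint[key]

print(sorted(text_model_dict))
# ['encoder.block.0.layer.0.SelfAttention.q.weight', 'shared.weight']
```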