[LoRA] fix flux lora loader when return_metadata is true for non-diffusers #11716

Merged: 2 commits, Jun 16, 2025
46 changes: 37 additions & 9 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2031,18 +2031,36 @@ def lora_state_dict(
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
             # Kohya already takes care of scaling the LoRA parameters with alpha.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
 
         is_xlabs = any("processor" in k for k in state_dict)
         if is_xlabs:
             state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
             # xlabs doesn't use `alpha`.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
 
         is_bfl_control = any("query_norm.scale" in k for k in state_dict)
         if is_bfl_control:
             state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
 
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2061,12 +2079,13 @@ def lora_state_dict(
             )
 
         if return_alphas or return_lora_metadata:
-            outputs = [state_dict]
-            if return_alphas:
-                outputs.append(network_alphas)
-            if return_lora_metadata:
-                outputs.append(metadata)
-            return tuple(outputs)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=network_alphas,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
         else:
             return state_dict
 
@@ -2785,6 +2804,15 @@ def _get_weight_shape(weight: torch.Tensor):
 
         raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
 
+    @staticmethod
+    def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
+        outputs = [state_dict]
+        if return_alphas:
+            outputs.append(alphas)
+        if return_metadata:
+            outputs.append(metadata)
+        return tuple(outputs) if (return_alphas or return_metadata) else state_dict
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
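For reference, a minimal standalone sketch of the return behaviour the new `_prepare_outputs` helper consolidates; the state dict and metadata values below are placeholder stand-ins, not real LoRA weights:

```python
# Standalone copy of the helper added above, for illustration only.
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
    outputs = [state_dict]
    if return_alphas:
        outputs.append(alphas)
    if return_metadata:
        outputs.append(metadata)
    # Only build a tuple when extra items were requested; otherwise keep
    # returning the bare state dict for backward compatibility.
    return tuple(outputs) if (return_alphas or return_metadata) else state_dict


state_dict = {"transformer.x.lora_A.weight": "..."}  # placeholder
metadata = {"r": 16, "lora_alpha": 16}               # placeholder

print(_prepare_outputs(state_dict, metadata))                        # state dict only
print(_prepare_outputs(state_dict, metadata, return_metadata=True))  # (state_dict, metadata)
print(_prepare_outputs(state_dict, metadata, return_alphas=True, return_metadata=True))
# (state_dict, None, metadata): previously the kohya/xlabs/BFL branches returned
# (state_dict, None) and dropped the metadata, which this PR fixes.
```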
4 changes: 3 additions & 1 deletion src/diffusers/loaders/peft.py
@@ -187,7 +187,9 @@ def load_lora_adapter(
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
-            metadata: TODO
+            metadata:
+                LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
+                initialize `LoraConfig`.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
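The docstring addition describes how a supplied metadata dict is meant to be consumed; a hedged illustration of that semantics, using invented example values rather than anything from this PR:

```python
# Hedged illustration of the documented behaviour: when `metadata` is provided,
# it seeds `LoraConfig` directly instead of the config being inferred from the
# state dict. The values below are invented example values.
from peft import LoraConfig

lora_metadata = {"r": 16, "lora_alpha": 16, "target_modules": ["to_q", "to_k", "to_v"]}
lora_config = LoraConfig(**lora_metadata)
print(lora_config.r, lora_config.lora_alpha)
```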
7 changes: 5 additions & 2 deletions src/diffusers/utils/state_dict_utils.py
@@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str):
         metadata = f.metadata() or {}
 
     metadata.pop("format", None)
-    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
-    return json.loads(raw) if raw else None
+    if metadata:
+        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+        return json.loads(raw) if raw else None
+    else:
+        return None
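For context, a standalone sketch of the updated metadata reader, assuming a safetensors checkpoint path; the `LORA_ADAPTER_METADATA_KEY` literal below is only a stand-in for the constant diffusers actually imports:

```python
import json

from safetensors import safe_open

# Stand-in for diffusers' LORA_ADAPTER_METADATA_KEY constant; the real value
# should come from diffusers rather than being hard-coded here.
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


def load_sft_state_dict_metadata(model_file: str):
    # safetensors exposes the header metadata as a flat str -> str mapping
    # (or None when the file carries no metadata at all, hence the `or {}`).
    with safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}

    # Drop the "format" entry safetensors writes by default, then bail out
    # early if nothing else is left, mirroring the guard added in this PR.
    metadata.pop("format", None)
    if metadata:
        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
        return json.loads(raw) if raw else None
    else:
        return None
```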