Fix monkey-patch for text_encoder LoRA #3490

Closed

10 changes: 7 additions & 3 deletions src/diffusers/loaders.py
@@ -946,11 +946,15 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
old_forward = module.forward

def new_forward(x):
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
Member (commenting on lines -950 to +949):

Here I'd not actually mind referring users to read this issue comment you made:
#3490 (comment)

def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)
Reviewer:

Small comment: if you load lora with e.g. pipeline.load_lora_weights("experiments/base_experiment") more than once, then this monkey patch becomes recursive!

Contributor Author:

Thanks for your suggestion! Indeed, it seems that might be the case. I wonder if it might be better to create a mechanism to remove the monkey-patch. @sayakpaul WDYT?

Reviewer:

Best is indeed to do something else than monkey-patching, but a flag like override_forward=False as an arg would also be helpful to disable the monkey-patching.

Member:

We're currently discussing this internally, and will keep y'all posted.

Member:

Meanwhile,

> Best is indeed to do something else than monkey-patching, but a flag like override_forward=False as an arg would also be helpful to disable the monkey-patching.

@rvorias could you elaborate what you mean here?

Reviewer:

Have load_lora_weights parse a kwarg override_text_encoder_forward:bool, then pass that to _modify_text_encoder. Then condition line 957 on this flag.

This is useful in contexts where you want to load arbitrary lora weights on the fly in a long-running SD inference engine.
Right now, calling load_lora_weights multiple times overrides the forward function repeatedly, so the LoRA addition term gets nested.

If you add the flag and condition, the new LoRA weights still load, but the forward is not overridden again and again.
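
To make the nesting concrete: each patch wraps whatever module.forward currently is, so a second load_lora_weights call wraps the already-patched forward and the LoRA term is applied twice. A minimal sketch of one possible guard that always wraps the pristine forward, so re-patching replaces rather than nests. This is hypothetical: patch_module and the _original_forward attribute are illustrative names, not part of diffusers or of this PR.

# Hypothetical sketch of an idempotent monkey-patch; `patch_module` and
# `_original_forward` are illustrative names, not diffusers API.
def make_new_forward(old_forward, lora_layer):
    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    return new_forward


def patch_module(module, lora_layer):
    # Remember the pristine forward the first time this module is patched.
    if not hasattr(module, "_original_forward"):
        module._original_forward = module.forward
    # Always wrap the original forward, so repeated calls replace the patch
    # instead of stacking LoRA terms on top of each other.
    module.forward = make_new_forward(module._original_forward, lora_layer)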

Member:

Would you be willing to open a PR for this? We're more than happy to help you with that :-)

return new_forward

# Monkey-patch.
module.forward = new_forward
module.forward = make_new_forward(old_forward, lora_layer)

def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
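
For context on why the extra factory function is needed: Python closures capture variables by reference, not by value, so a new_forward defined directly in the loop would see only the last old_forward and lora_layer once the loop finishes. Wrapping the definition in make_new_forward creates a fresh scope per iteration. A minimal, self-contained sketch of the pitfall (plain Python, not diffusers code):

# Minimal sketch of the late-binding pitfall the patch avoids (illustrative, not diffusers code).
def build_broken(funcs):
    wrapped = []
    for f in funcs:
        def call():
            return f()  # `f` is looked up when `call` runs, so every wrapper sees the last `f`
        wrapped.append(call)
    return wrapped


def build_fixed(funcs):
    def make_call(f):  # a new scope locks in the current `f` for each wrapper
        def call():
            return f()
        return call
    return [make_call(f) for f in funcs]


funcs = [lambda: "a", lambda: "b"]
print([c() for c in build_broken(funcs)])  # ['b', 'b']
print([c() for c in build_fixed(funcs)])   # ['a', 'b']

make_new_forward in the diff above plays the same role as make_call here, locking in old_forward and lora_layer for each patched text-encoder module.
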
79 changes: 78 additions & 1 deletion tests/models/test_lora_layers.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest
@@ -43,15 +44,33 @@ def create_unet_lora_layers(unet: nn.Module):
return lora_attn_procs, unet_lora_layers


def create_text_encoder_lora_layers(text_encoder: nn.Module):
def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
return text_lora_attn_procs


def create_text_encoder_lora_layers(text_encoder: nn.Module):
text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
return text_encoder_lora_layers


def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
for _, attn_proc in text_lora_attn_procs.items():
# set up.weights
for layer_name, layer_module in attn_proc.named_modules():
if layer_name.endswith("_lora"):
weight = (
torch.randn_like(layer_module.up.weight)
if randn_weight
else torch.zeros_like(layer_module.up.weight)
)
layer_module.up.weight = torch.nn.Parameter(weight)


class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
@@ -212,3 +231,61 @@ def test_lora_save_load_legacy(self):

# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
max_seq_length = 77

inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))

prepared_inputs = {}
prepared_inputs["input_ids"] = inputs
return prepared_inputs

def test_text_encoder_lora_monkey_patch(self):
pipeline_components, _ = self.get_dummy_components()
pipe = StableDiffusionPipeline(**pipeline_components)

dummy_tokens = self.get_dummy_tokens()

# inference without lora
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32)

# create lora_attn_procs with zeroed out up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=False)

# monkey patch
pipe._modify_text_encoder(text_attn_procs)

# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()

# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)

assert torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"

# create lora_attn_procs with randn up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=True)

# monkey patch
pipe._modify_text_encoder(text_attn_procs)

# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()

# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)

assert not torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"