Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,17 @@ def main(args):
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
Expand Down
69 changes: 39 additions & 30 deletions examples/text_to_image/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,13 @@ def main():
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
for param in unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
Expand Down Expand Up @@ -888,39 +894,42 @@ def collate_fn(examples):
ignore_patterns=["step_*", "epoch_*"],
)

# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)
# Final inference
# Load previous pipeline
if args.validation_prompt is not None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not validation_prompt was passed we must not run this step.

pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure to use load_lora_weights() instead of load_attn_procs().


# run inference
generator = torch.Generator(device=accelerator.device)
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
# run inference
generator = torch.Generator(device=accelerator.device)
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])

if accelerator.is_main_process:
for tracker in accelerator.trackers:
if len(images) != 0:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
for tracker in accelerator.trackers:
if len(images) != 0:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)

accelerator.end_training()

Expand Down