diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c3a78eae34d7..9992292e30aa 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -991,6 +991,17 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ca699c863eb6..c8efbddd0b44 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -460,7 +460,13 @@ def main(): vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) + # Add adapter and make sure the trainable params are in float32. unet.add_adapter(unet_lora_config) + if args.mixed_precision == "fp16": + for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -888,39 +894,42 @@ def collate_fn(examples): ignore_patterns=["step_*", "epoch_*"], ) - # Final inference - # Load previous pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype - ) - pipeline = pipeline.to(accelerator.device) + # Final inference + # Load previous pipeline + if args.validation_prompt is not None: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) - # load attention processors - pipeline.unet.load_attn_procs(args.output_dir) + # load attention processors + pipeline.load_lora_weights(args.output_dir) - # run inference - generator = torch.Generator(device=accelerator.device) - if args.seed is not None: - generator = generator.manual_seed(args.seed) - images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) - if accelerator.is_main_process: - for tracker in accelerator.trackers: - if len(images) != 0: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if len(images) != 0: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) accelerator.end_training()