diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py index 07df6f201175..821c66b7237a 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py @@ -735,7 +735,7 @@ def collate_fn(examples): torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) for mask in masks ] - ) + ).to(dtype=weight_dtype) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents