Skip to content

Commit 0ddc5bf

Browse files
authored
fix mixed precision training on train_dreambooth_inpaint_lora (#3138)
cast to weight dtype
1 parent c5933c9 commit 0ddc5bf

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def collate_fn(examples):
735735
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
736736
for mask in masks
737737
]
738-
)
738+
).to(dtype=weight_dtype)
739739
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
740740

741741
# Sample noise that we'll add to the latents

0 commit comments

Comments
 (0)