Skip to content

Commit 262d539

Browse files
Correct multi gpu dreambooth (#3673)
Correct multi gpu
1 parent 0fc2fb7 commit 262d539

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,7 @@ def compute_text_embeddings(prompt):
12111211
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
12121212
)
12131213

1214-
if unet.config.in_channels == channels * 2:
1214+
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
12151215
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
12161216

12171217
if args.class_labels_conditioning == "timesteps":

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,7 @@ def compute_text_embeddings(prompt):
11561156
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
11571157
)
11581158

1159-
if unet.config.in_channels == channels * 2:
1159+
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
11601160
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
11611161

11621162
if args.class_labels_conditioning == "timesteps":

0 commit comments

Comments
 (0)