Skip to content

Commit e607a58

Browse files
authored
[Examples] Fix type-casting issue in the ControlNet training script (#2994)
* fix: norm group test for UNet3D. * fix: type-casting issue in controlnet training.
1 parent ea39cd7 commit e607a58

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -972,8 +972,10 @@ def load_model_hook(models, input_dir):
972972
noisy_latents,
973973
timesteps,
974974
encoder_hidden_states=encoder_hidden_states,
975-
down_block_additional_residuals=down_block_res_samples,
976-
mid_block_additional_residual=mid_block_res_sample,
975+
down_block_additional_residuals=[
976+
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
977+
],
978+
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
977979
).sample
978980

979981
# Get the target for loss depending on the prediction type

0 commit comments

Comments
 (0)