Skip to content

Commit e2d9a9b

Browse files
andrehuangpatrickvonplatenwilliamberman
authored
fix the in-place modification in unet condition when using controlnet (#2586)
* fix the in-place modification in unet condition when using controlnet, which will cause backprop errors when training * add clone to mid block * fix-copies --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: William Berman <[email protected]>
1 parent f9cfb5a commit e2d9a9b

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def forward(
598598
for down_block_res_sample, down_block_additional_residual in zip(
599599
down_block_res_samples, down_block_additional_residuals
600600
):
601-
down_block_res_sample += down_block_additional_residual
601+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
602602
new_down_block_res_samples += (down_block_res_sample,)
603603

604604
down_block_res_samples = new_down_block_res_samples
@@ -614,7 +614,7 @@ def forward(
614614
)
615615

616616
if mid_block_additional_residual is not None:
617-
sample += mid_block_additional_residual
617+
sample = sample + mid_block_additional_residual
618618

619619
# 5. up
620620
for i, upsample_block in enumerate(self.up_blocks):

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def forward(
688688
for down_block_res_sample, down_block_additional_residual in zip(
689689
down_block_res_samples, down_block_additional_residuals
690690
):
691-
down_block_res_sample += down_block_additional_residual
691+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
692692
new_down_block_res_samples += (down_block_res_sample,)
693693

694694
down_block_res_samples = new_down_block_res_samples
@@ -704,7 +704,7 @@ def forward(
704704
)
705705

706706
if mid_block_additional_residual is not None:
707-
sample += mid_block_additional_residual
707+
sample = sample + mid_block_additional_residual
708708

709709
# 5. up
710710
for i, upsample_block in enumerate(self.up_blocks):

0 commit comments

Comments
 (0)