Skip to content

Commit 139f707

Browse files
authored
Correction for non-integral image resolutions with quantizations other than float32 (#7356)
* Correction for non-integral image resolutions with quantizations other than float32. * Support for training, and use of diffusers-style casting.
1 parent e4546fd commit 139f707

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/diffusers/models/unets/unet_stable_cascade.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,9 +521,11 @@ def custom_forward(*inputs):
521521
if isinstance(block, SDCascadeResBlock):
522522
skip = level_outputs[i] if k == 0 and i > 0 else None
523523
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
524+
orig_type = x.dtype
524525
x = torch.nn.functional.interpolate(
525526
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
526527
)
528+
x = x.to(orig_type)
527529
x = torch.utils.checkpoint.checkpoint(
528530
create_custom_forward(block), x, skip, use_reentrant=False
529531
)
@@ -547,9 +549,11 @@ def custom_forward(*inputs):
547549
if isinstance(block, SDCascadeResBlock):
548550
skip = level_outputs[i] if k == 0 and i > 0 else None
549551
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
552+
orig_type = x.dtype
550553
x = torch.nn.functional.interpolate(
551554
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
552555
)
556+
x = x.to(orig_type)
553557
x = block(x, skip)
554558
elif isinstance(block, SDCascadeAttnBlock):
555559
x = block(x, clip)

0 commit comments

Comments
 (0)