diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 2df6e60d88c9..cc514d26781e 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -181,17 +181,30 @@ def __init__( resnet_groups=norm_num_groups, add_attention=add_attention, ) - + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - is_final_block = i == len(block_out_channels) - 1 - + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, @@ -199,7 +212,7 @@ def __init__( out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, - add_upsample=not is_final_block, + add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, @@ -235,6 +248,25 @@ def forward( [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ + + + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + forward_upsample_size = True + + + + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -288,14 +320,21 @@ def forward( # 5. up skip_sample = None - for upsample_block in self.up_blocks: + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "skip_conv"): - sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample,upsample_size=upsample_size) else: - sample = upsample_block(sample, res_samples, emb) + sample = upsample_block(sample, res_samples, emb,upsample_size=upsample_size) # 6. post-process sample = self.conv_norm_out(sample) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3070351279b8..38d586ecdf04 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1673,7 +1673,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None,upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1685,7 +1685,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states,upsample_size) return hidden_states