diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index fbd78b512a6b..7bb5416adf24 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -48,6 +48,10 @@ def forward(self, hidden_states, output_size=None): if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: @@ -376,6 +380,10 @@ def forward(self, input_tensor, temb): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: