diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index db4c33690c9d..12a15dcb751b 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -120,7 +120,6 @@ def __init__(
     def forward(
         self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
     ) -> Dict[str, torch.FloatTensor]:
-
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -132,8 +131,8 @@ def forward(
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
 
-        # broadcast to batch dimension
-        timesteps = timesteps.broadcast_to(sample.shape[0])
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
 
         t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 25c4e37d8a6d..acba923bf81a 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -121,7 +121,6 @@ def forward(
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
     ) -> Dict[str, torch.FloatTensor]:
-
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -133,8 +132,8 @@ def forward(
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
 
-        # broadcast to batch dimension
-        timesteps = timesteps.broadcast_to(sample.shape[0])
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
 
         t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
@@ -145,7 +144,6 @@ def forward(
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-
             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                 sample, res_samples = downsample_block(
                     hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
@@ -160,7 +158,6 @@ def forward(
 
         # 5. up
        for upsample_block in self.up_blocks:
-
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
 
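
Below is a minimal standalone sketch (not part of the patch) illustrating why the multiplication-based broadcast is a drop-in replacement for broadcast_to: both produce the same (batch_size,) timestep tensor, but the multiplication traces to plain elementwise ops that ONNX/Core ML converters are more comfortable with, as the new comment in the diff notes. The variable names (batch_size, sample, timesteps) are illustrative stand-ins, not taken verbatim from the diffusers modules.

import torch

# Stand-ins for the UNet forward() inputs; names are illustrative only.
batch_size = 4
sample = torch.randn(batch_size, 3, 64, 64)   # latent/image input
timesteps = torch.tensor(10)                  # 0-dim timestep, the scalar case handled above

# Promote the scalar to a 1-element tensor, mirroring the elif branch in the diff.
timesteps = timesteps[None].to(sample.device)

# Old approach: a view-based broadcast, which the diff replaces for export compatibility.
broadcast_view = timesteps.broadcast_to((sample.shape[0],))

# New approach from the patch: materialize the broadcast via multiplication with ones.
broadcast_mul = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

# Both give identical values with shape (batch_size,).
assert torch.equal(broadcast_view, broadcast_mul)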