From 2160f327b371ea483ebf41bb5a84539fd4e0de3d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 1 Sep 2022 10:18:15 +0200 Subject: [PATCH 1/3] Use ONNX / Core ML compatible method to broadcast. Unfortunately `tile` could not be used either, it's still not compatible with ONNX. See #284. --- src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index db4c33690c9d..9e6ed97342c4 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -133,7 +133,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension - timesteps = timesteps.broadcast_to(sample.shape[0]) + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 25c4e37d8a6d..fd9db307b360 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -134,7 +134,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension - timesteps = timesteps.broadcast_to(sample.shape[0]) + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb) From 91aa34f48f1632b21d3b0e6e26a1a6c8d7a4df1a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 2 Sep 2022 09:27:36 +0200 Subject: [PATCH 2/3] Add comment about why broadcast_to is not used. Also, apply style to changed files. --- src/diffusers/models/unet_2d.py | 3 +-- src/diffusers/models/unet_2d_condition.py | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 9e6ed97342c4..14d6db70c720 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -120,7 +120,6 @@ def __init__( def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int] ) -> Dict[str, torch.FloatTensor]: - # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -132,7 +131,7 @@ def forward( elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) - # broadcast to batch dimension + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype) t_emb = self.time_proj(timesteps) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index fd9db307b360..92a2161f9eae 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -121,7 +121,6 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, ) -> Dict[str, torch.FloatTensor]: - # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -133,7 +132,7 @@ def forward( elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) - # broadcast to batch dimension + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype) t_emb = self.time_proj(timesteps) @@ -145,7 +144,6 @@ def forward( # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: - if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states @@ -160,7 +158,6 @@ def forward( # 5. up for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] From 1f06c95cc27a73e08e3093adba40711db742ef74 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 2 Sep 2022 09:42:50 +0200 Subject: [PATCH 3/3] Make sure broadcast remains in same device. --- src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 14d6db70c720..12a15dcb751b 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -132,7 +132,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype) + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 92a2161f9eae..acba923bf81a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -133,7 +133,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype) + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb)