Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,24 @@ def forward(self, hidden_states, output_size=None):
if self.use_conv_transpose:
return self.conv(hidden_states)

# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)

# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ def __call__(
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/stable_diffusion/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)

special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

result = []
batch_size = image_embeds.shape[0]
Expand Down