diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 3c23693b40ef..e96c0c7467f3 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -444,7 +444,11 @@ def numpy_to_pil(images): if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] return pil_images diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index c9c58a748831..c0a44363a2f9 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -625,7 +625,11 @@ def numpy_to_pil(images): if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] return pil_images