Skip to content

Commit a276639

Browse files
committed
Fix preprocess images
1 parent 9d10981 commit a276639

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,8 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image
120120

121121
if isinstance(image, Image.Image):
122122
image = [image]
123-
processed_image = []
124-
for img in image:
125-
processed_image.append(preprocess(img, self.dtype))
126-
processed_image = jnp.array(processed_image).squeeze()
123+
124+
processed_images = jnp.array([preprocess(img, jnp.float32) for img in image])
127125

128126
text_input = self.tokenizer(
129127
prompt,
@@ -132,7 +130,7 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image
132130
truncation=True,
133131
return_tensors="np",
134132
)
135-
return text_input.input_ids, processed_image
133+
return text_input.input_ids, processed_images
136134

137135
def _get_has_nsfw_concepts(self, features, params):
138136
has_nsfw_concepts = self.safety_checker(features, params)

0 commit comments

Comments
 (0)