Skip to content

Commit 0671472

Browse files
camenduruPrathik Rao
authored andcommitted
Fix Flax pipeline: width and height are ignored huggingface#838 (huggingface#848)
* Fix Flax pipeline: width and height are ignored huggingface#838 * Fix Flax pipeline: width and height are ignored
1 parent 8e4792f commit 0671472

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,7 @@ def _generate(
152152
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
153153
context = jnp.concatenate([uncond_embeddings, text_embeddings])
154154

155-
latents_shape = (
156-
batch_size,
157-
self.unet.in_channels,
158-
self.unet.sample_size,
159-
self.unet.sample_size,
160-
)
155+
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
161156
if latents is None:
162157
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
163158
else:

0 commit comments

Comments
 (0)