From 949096d89dbc97b519ebda8cc6aef4d2a106aa58 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Apr 2023 11:45:26 +0000 Subject: [PATCH 1/2] Fix invocation of some slow tests. We use __call__ rather than pmapping the generation function ourselves because the number of static arguments is different now. --- tests/test_pipelines_flax.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index a461930f3a83..e7c333cd36ef 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -70,14 +70,12 @@ def test_dummy_all_tpus(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 64, 64, 3) if jax.device_count() == 8: @@ -105,14 +103,12 @@ def test_stable_diffusion_v1_4(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: @@ -136,14 +132,12 @@ def test_stable_diffusion_v1_4_bfloat_16(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: @@ -211,14 +205,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: From 010cd16759077ef1e1a4537ed10e233e2552f39e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Apr 2023 19:08:54 +0000 Subject: [PATCH 2/2] style --- tests/test_pipelines_flax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index e7c333cd36ef..aab2eb9a07fb 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -28,7 +28,6 @@ import jax.numpy as jnp from flax.jax_utils import replicate from flax.training.common_utils import shard - from jax import pmap from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline