-
Notifications
You must be signed in to change notification settings - Fork 6.6k
Align PT and Flax API - allow loading checkpoint from PyTorch configs #827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1a4c2e3
28e42f7
cf2f202
15ce104
8a15889
b29b554
39965f1
c3127e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
|
|
||
| if is_flax_available(): | ||
| import jax | ||
| import jax.numpy as jnp | ||
| from diffusers import FlaxStableDiffusionPipeline | ||
| from flax.jax_utils import replicate | ||
| from flax.training.common_utils import shard | ||
|
|
@@ -34,7 +35,7 @@ | |
| class FlaxPipelineTests(unittest.TestCase): | ||
| def test_dummy_all_tpus(self): | ||
| pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( | ||
| "hf-internal-testing/tiny-stable-diffusion-pipe" | ||
| "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None | ||
| ) | ||
|
|
||
| prompt = ( | ||
|
|
@@ -57,6 +58,103 @@ def test_dummy_all_tpus(self): | |
| prompt_ids = shard(prompt_ids) | ||
|
|
||
| images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images | ||
|
|
||
| assert images.shape == (8, 1, 64, 64, 3) | ||
| assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3 | ||
| assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2 | ||
|
|
||
| images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) | ||
|
|
||
| assert len(images_pil) == 8 | ||
|
|
||
def test_stable_diffusion_v1_4(self):
    """Check full-precision Stable Diffusion v1-4 output under an outer ``pmap``.

    Loads the ``flax`` revision of ``CompVis/stable-diffusion-v1-4`` with the
    safety checker disabled, samples with a fixed seed and prompt on every
    available device, and compares the result with recorded reference sums.
    """
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
    )

    prompt = (
        "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
        " field, close up, split lighting, cinematic"
    )

    prng_seed = jax.random.PRNGKey(0)
    num_inference_steps = 50

    # One copy of the prompt per device.
    num_samples = jax.device_count()
    prompt = num_samples * [prompt]
    prompt_ids = pipeline.prepare_inputs(prompt)

    # Positional argument 3 (num_inference_steps) is broadcast as a static
    # value instead of being mapped over the device axis.
    p_sample = jax.pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

    # shard inputs and rng
    params = replicate(params)
    # One RNG key per device; derived from num_samples rather than a literal 8
    # so the leading axis always matches the sharded prompt ids.
    prng_seed = jax.random.split(prng_seed, num_samples)
    prompt_ids = shard(prompt_ids)

    images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

    # Exercise the PIL conversion without writing to a machine-specific path.
    images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
    assert len(images_pil) == num_samples

    assert images.shape == (num_samples, 1, 512, 512, 3)
    assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
    # NOTE(review): this total is summed over every sample; the test originally
    # asserted 8 samples, so the reference presumably only holds on 8 devices.
    assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2
|
|
||
def test_stable_diffusion_v1_4_bfloat_16(self):
    """Check bfloat16 Stable Diffusion v1-4 output under an outer ``pmap``.

    Loads the ``bf16`` revision of ``CompVis/stable-diffusion-v1-4`` with
    ``dtype=jnp.bfloat16`` and the safety checker disabled, samples with a
    fixed seed and prompt on every available device, and compares the result
    with recorded reference sums.
    """
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
    )

    prompt = (
        "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
        " field, close up, split lighting, cinematic"
    )

    prng_seed = jax.random.PRNGKey(0)
    num_inference_steps = 50

    # One copy of the prompt per device.
    num_samples = jax.device_count()
    prompt = num_samples * [prompt]
    prompt_ids = pipeline.prepare_inputs(prompt)

    # Positional argument 3 (num_inference_steps) is broadcast as a static
    # value instead of being mapped over the device axis.
    p_sample = jax.pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

    # shard inputs and rng
    params = replicate(params)
    # One RNG key per device; derived from num_samples rather than a literal 8
    # so the leading axis always matches the sharded prompt ids.
    prng_seed = jax.random.split(prng_seed, num_samples)
    prompt_ids = shard(prompt_ids)

    images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

    assert images.shape == (num_samples, 1, 512, 512, 3)
    assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
    # NOTE(review): this total is summed over every sample; the test originally
    # asserted 8 samples, so the reference presumably only holds on 8 devices.
    assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2
|
|
||
| def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. with safety + inner pmap |
||
| pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( | ||
| "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16 | ||
| ) | ||
|
|
||
| prompt = ( | ||
| "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" | ||
| " field, close up, split lighting, cinematic" | ||
| ) | ||
|
|
||
| prng_seed = jax.random.PRNGKey(0) | ||
| num_inference_steps = 50 | ||
|
|
||
| num_samples = jax.device_count() | ||
| prompt = num_samples * [prompt] | ||
| prompt_ids = pipeline.prepare_inputs(prompt) | ||
|
|
||
| # shard inputs and rng | ||
| params = replicate(params) | ||
| prng_seed = jax.random.split(prng_seed, 8) | ||
|
Collaborator
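There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Better to replace the hard-coded `8` in the line above (the inline code of this comment was lost in extraction; the suggested replacement was presumably `jax.device_count()` / `num_samples`).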
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👍 agree
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this would indeed be a good idea (in case someone wants to open a PR for it, please feel free) |
||
| prompt_ids = shard(prompt_ids) | ||
|
|
||
| images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images | ||
|
|
||
| assert images.shape == (8, 1, 512, 512, 3) | ||
| assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 | ||
| assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no safety + outer pmap