[Flax] Add Vae for Stable Diffusion #555
Conversation
Will add tests and docstring later.

To verify that it works:

```python
#!/usr/bin/env python3
from diffusers import FlaxAutoencoderKL, AutoencoderKL
import torch
import numpy as np
import jax.numpy as jnp

model_flax, params = FlaxAutoencoderKL.from_pretrained("fusing/sd-v1-4-flax", subfolder="vae", use_auth_token=True)
model_pt = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

sample_shape = (1, model_flax.config.in_channels, model_flax.config.sample_size, model_flax.config.sample_size)
sample = jnp.ones(sample_shape, dtype=jnp.float32)

sample_out_flax = model_flax.apply({"params": params}, sample)

sample_pt = torch.from_numpy(np.asarray(sample))
with torch.no_grad():
    sample_out_pt = model_pt(sample_pt)

max_diff = np.max(np.abs(np.asarray(sample_out_flax.sample) - sample_out_pt.sample.numpy()))
print("Max diff", max_diff)
```
The documentation is not available anymore as the PR was closed or merged.
pcuenca
left a comment
Looks great. I'd suggest testing it with pmap if you haven't done so already, to ensure it works on multiple devices in jitted mode as well as on a single device.
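A minimal sketch of the kind of pmap check being suggested. The tiny matmul "model" below is a hypothetical stand-in, not the actual FlaxAutoencoderKL; the real check would call `model_flax.apply({"params": params}, sample)` with the replicated params and a per-device batch.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a Flax apply function (assumption for
# illustration); the real test would use model_flax.apply instead.
def apply_fn(params, x):
    return x @ params["w"]

params = {"w": jnp.ones((4, 4))}
n_devices = jax.local_device_count()

# Replicate the params across devices and shard the batch along axis 0,
# which pmap maps over.
replicated = jax.tree_util.tree_map(
    lambda p: jnp.broadcast_to(p, (n_devices,) + p.shape), params
)
batch = jnp.ones((n_devices, 2, 4))

# pmap jit-compiles apply_fn and runs it on every local device in parallel.
out = jax.pmap(apply_fn)(replicated, batch)
print(out.shape)  # (n_devices, 2, 4)
```

This exercises the same code path (compilation plus cross-device execution) on a single-device machine, where `jax.local_device_count()` is simply 1.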
patil-suraj
left a comment
Looks great! Just left some nits to align the names with PyTorch.
Two things to be addressed before merging:
- Use lists for the downsamplers and upsamplers, the same way we do in PT. This will make auto-conversion possible.
- Transpose the output in the `decode` method, as that method is used in the pipeline.
For a follow-up PR:
- We should move the blocks into their own file, `unet_blocks_flax.py`.
- Make sure all module names are identical to their PT counterparts.
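For context on the transpose point: Flax/JAX convolutions operate on NHWC tensors, while PyTorch and the diffusers pipeline expect NCHW, so the decoder output needs a layout change before being handed to the pipeline. A sketch (the shapes here are illustrative assumptions, not the model's actual config):

```python
import numpy as np

# Illustrative decoder output in Flax's NHWC layout: (batch, height, width, channels)
nhwc = np.ones((1, 64, 64, 3))

# Reorder axes to PyTorch's NCHW layout: (batch, channels, height, width)
nchw = np.transpose(nhwc, (0, 3, 1, 2))
print(nchw.shape)  # (1, 3, 64, 64)
```

The same reordering with `jnp.transpose` inside `decode` keeps the Flax pipeline's output shape-compatible with the PyTorch one.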
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: dan <[email protected]>
* [Flax] Add Vae
* correct
* Apply suggestions from code review
Co-authored-by: Suraj Patil <[email protected]>
* Finish
Co-authored-by: Suraj Patil <[email protected]>