
Conversation

@patrickvonplaten (Contributor)

No description provided.

@patrickvonplaten (Contributor, Author)

Will add tests and docstrings later.

To verify that it works:

#!/usr/bin/env python3
from diffusers import FlaxAutoencoderKL, AutoencoderKL
import torch
import numpy as np
import jax.numpy as jnp

# Load the same VAE checkpoint in Flax and in PyTorch.
model_flax, params = FlaxAutoencoderKL.from_pretrained("fusing/sd-v1-4-flax", subfolder="vae", use_auth_token=True)
model_pt = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

# Build a dummy all-ones input batch: (batch, channels, height, width).
sample_shape = (1, model_flax.config.in_channels, model_flax.config.sample_size, model_flax.config.sample_size)
sample = jnp.ones(sample_shape, dtype=jnp.float32)

# Forward pass through the Flax model.
sample_out_flax = model_flax.apply({"params": params}, sample)

# Forward pass through the PyTorch model on the same input.
sample_pt = torch.from_numpy(np.asarray(sample))
with torch.no_grad():
    sample_out_pt = model_pt(sample_pt)

# The two implementations should agree up to numerical noise.
max_diff = np.max(np.abs(np.asarray(sample_out_flax.sample) - sample_out_pt.sample.numpy()))
print("Max diff", max_diff)

@HuggingFaceDocBuilderDev commented Sep 18, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca pcuenca (Member) left a comment


Looks great. I'd suggest testing it with pmap if you haven't done so already, to ensure it works on multiple devices in jitted mode as well as on a single device.
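
For reference, a minimal sketch of such a pmap test, reusing model_flax, params, sample_shape, and the jnp import from the verification script above; the replication and axis handling here are illustrative, not the PR's actual test code:

import jax

n_devices = jax.local_device_count()

# Replicate the parameters to every local device and give each device one sample.
replicated_params = jax.device_put_replicated(params, jax.local_devices())
batched_sample = jnp.ones((n_devices,) + sample_shape, dtype=jnp.float32)

@jax.pmap
def run_vae(p, x):
    # Same forward pass as the single-device check, mapped over the device axis.
    return model_flax.apply({"params": p}, x).sample

out = run_vae(replicated_params, batched_sample)
print(out.shape)  # leading axis is the device axis: (n_devices, 1, C, H, W)

pmap jit-compiles the function and runs it on every local device, so a single call exercises both jitted execution and the multi-device path.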

@patil-suraj patil-suraj (Contributor) left a comment


Looks great! Just left some nits to align the names with PyTorch.

Two things to be addressed before merging:

  • Use lists for downsamplers and upsamplers, the same way we do in PT. This will make auto-conversion possible (see the sketch after this list).
  • Transpose the output in the decode method, as that method is used in the pipeline.
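
As a hedged illustration of the first point, a minimal Flax module holding its downsamplers in a list; FlaxDownsample2D and the exact structure here are stand-ins, not necessarily what this PR ships:

import flax.linen as nn

class FlaxDownsample2D(nn.Module):
    # Stand-in downsampler: a strided conv, mirroring PT's Downsample2D.
    out_channels: int

    def setup(self):
        self.conv = nn.Conv(self.out_channels, kernel_size=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))

    def __call__(self, x):
        return self.conv(x)

class FlaxDownBlock2D(nn.Module):
    out_channels: int

    def setup(self):
        # Storing submodules in a list gives Flax parameter paths like
        # "downsamplers_0", which map mechanically onto PyTorch's
        # nn.ModuleList paths ("downsamplers.0") during weight conversion.
        self.downsamplers = [FlaxDownsample2D(self.out_channels)]

    def __call__(self, hidden_states):
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)
        return hidden_states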

For a follow-up PR:

  • We should move the blocks into their own file, unet_blocks_flax.py.
  • Make sure all module names are identical to their PT counterparts.

@patrickvonplaten patrickvonplaten merged commit bf5ca03 into main Sep 19, 2022
@patrickvonplaten patrickvonplaten deleted the add_flax_vae branch September 19, 2022 14:00
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* [Flax] Add Vae

* correct

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* Finish

Co-authored-by: Suraj Patil <[email protected]>