Skip to content

Flax controlnet #2727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 23, 2023
Merged

Flax controlnet #2727

merged 13 commits into from
Mar 23, 2023

Conversation

yiyixuxu
Copy link
Collaborator

@yiyixuxu yiyixuxu commented Mar 17, 2023

to-do:

  • pass equivalency test
  • add pipeline
  • doc and testing

Example

import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers.utils import load_image
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel


def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(0)

# get canny image
canny_image = load_image("https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg")

prompts = "best quality, extremely detailed"
negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"

# load control net and stable diffusion v1-5
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype = jnp.float32)
params['controlnet'] = controlnet_params

num_samples = jax.device_count()
rng = jax.random.split(rng, jax.device_count())

prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    num_inference_steps=50,
    neg_prompt_ids=negative_prompt_ids,
    jit=True,
).images

output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
output_images = image_grid(output_images, num_samples//4, 4)
output_images.save("generated_image.png")

generated_image

CPU equivalency test

import jax
import jax.numpy as jnp

import numpy as np
import torch

from diffusers import  ControlNetModel, FlaxControlNetModel, UNet2DConditionModel, FlaxUNet2DConditionModel

ndtype = np.float32
jdtype = jnp.float32
tdtype = torch.float32

def to_np(x):
  """Converts tensors to numpy."""
  return np.asarray(x, dtype=ndtype)

# create models (controlnet + unet)
control_repo = "lllyasviel/sd-controlnet-canny"
controlnet = ControlNetModel.from_pretrained(control_repo)
controlnet_flax, controlnet_params = FlaxControlNetModel.from_pretrained(control_repo, from_pt=True)

unet_repo = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(unet_repo, subfolder='unet')
unet_flax, unet_params = FlaxUNet2DConditionModel.from_pretrained(unet_repo, subfolder='unet', from_pt=True)

# create inputs (pytorch)
sample = torch.rand(1, 4, 64, 64, dtype =tdtype)
t = 999
encoder_hidden_states = torch.rand( 1,77,768, dtype=tdtype)
controlnet_cond = torch.rand(1, 3, 512, 512, dtype=tdtype)

# run ControlNet in pytorch 
down_res, mid = controlnet(sample,
                    t,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=controlnet_cond,
                    conditioning_scale=1.,
                    return_dict=False,
                )
noise = unet(
    sample, 
    t, 
    encoder_hidden_states, 
    down_block_additional_residuals = down_res,
    mid_block_additional_residual = mid,).sample

# convert pytorch input to jnp
sample_f =  np.asarray(sample, dtype=jdtype)
encoder_hidden_states_f = jnp.asarray(encoder_hidden_states, dtype=jdtype)
controlnet_cond_f = jnp.asarray(controlnet_cond , dtype=jdtype)

# run ControlNet in flax
down_res_f, mid_f = controlnet_flax.apply(
    {'params': controlnet_params},
    sample_f,
    t,
    encoder_hidden_states_f,
    controlnet_cond_f,
    return_dict=False,)

noise_f = unet_flax.apply(
    {'params': unet_params},
    sample_f, 
    t, 
    encoder_hidden_states_f, 
    down_block_additional_residuals = down_res_f,
    mid_block_additional_residual = mid_f,).sample

# compare
print(np.max(np.abs(to_np(noise.detach())- to_np(noise_f))))
2.002716e-05

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@yiyixuxu yiyixuxu marked this pull request as ready for review March 20, 2023 15:54
@patrickvonplaten
Copy link
Contributor

PR looks great! Happy to merge if Flax Stable Diffusion Tests are still all passing

@yiyixuxu
Copy link
Collaborator Author

cc @patrickvonplaten

weight initialization compare
here I only compare layers in ControlNetModel that are not initialized from the pre-trained unet during the training

The default initializer in pytorch conv2d vs flax Conv is different, hence the slight difference you see in controlnet_cond_embedding layer - I don't think we need to match pytorch initializer though.

import jax
from diffusers import FlaxControlNetModel, ControlNetModel 

from diffusers.models.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from flax.core.frozen_dict import freeze


config_f = FlaxControlNetModel.load_config("lllyasviel/sd-controlnet-canny")
model_f = FlaxControlNetModel.from_config(config_f)

params_f = model_f.init_weights(rng=jax.random.PRNGKey(0))


config = ControlNetModel.load_config("lllyasviel/sd-controlnet-canny")
model = ControlNetModel.from_config(config)
params = convert_pytorch_state_dict_to_flax(model.state_dict(), model_f)
params = freeze(params)

for key in ['controlnet_cond_embedding','controlnet_mid_block'] + [f'controlnet_down_blocks_{i}' for i in range(11)]:
    print(f'compare parameter (mean, variance) for {key}')
    print("=== pytorch ===")
    print(jax.tree_map(lambda x: (x.mean().item(), x.var().item()), params[key]))
    print("=== flax ===")
    print(jax.tree_map(lambda x: (x.mean().item(), x.var().item()), params_f[key]))
    print(' ')
compare parameter (mean, variance) for controlnet_cond_embedding
=== pytorch ===
FrozenDict({
    blocks_0: {
        bias: (-0.01611514575779438, 0.001917642424814403),
        kernel: (-0.0008296407759189606, 0.0023712010588496923),
    },
    blocks_1: {
        bias: (-0.025777339935302734, 0.0016312293009832501),
        kernel: (0.0005276202573440969, 0.002340401755645871),
    },
    blocks_2: {
        bias: (-0.001183301443234086, 0.001086806645616889),
        kernel: (-0.0006763663259334862, 0.0011519063264131546),
    },
    blocks_3: {
        bias: (0.000355269992724061, 0.0008704514475539327),
        kernel: (9.844478336162865e-05, 0.0011636927956715226),
    },
    blocks_4: {
        bias: (-8.27490002848208e-05, 0.0004041626525577158),
        kernel: (-8.368547423742712e-05, 0.0003846006002277136),
    },
    blocks_5: {
        bias: (-0.0031753063667565584, 0.0003698799409903586),
        kernel: (-5.317967588780448e-05, 0.0003860045690089464),
    },
    conv_in: {
        bias: (-0.0007939962670207024, 0.012068698182702065),
        kernel: (0.004656326025724411, 0.012428686022758484),
    },
    conv_out: {
        bias: (0.0, 0.0),
        kernel: (0.0, 0.0),
    },
})
=== flax ===
FrozenDict({
    blocks_0: {
        bias: (0.0, 0.0),
        kernel: (0.00046943812048994005, 0.00703172804787755),
    },
    blocks_1: {
        bias: (0.0, 0.0),
        kernel: (-0.00201827147975564, 0.006816760636866093),
    },
    blocks_2: {
        bias: (0.0, 0.0),
        kernel: (0.000466530880657956, 0.003474178956821561),
    },
    blocks_3: {
        bias: (0.0, 0.0),
        kernel: (8.83411048562266e-05, 0.0034741892013698816),
    },
    blocks_4: {
        bias: (0.0, 0.0),
        kernel: (2.4200153347919695e-05, 0.001159254345111549),
    },
    blocks_5: {
        bias: (0.0, 0.0),
        kernel: (5.74419173062779e-05, 0.0011549323098734021),
    },
    conv_in: {
        bias: (0.0, 0.0),
        kernel: (-0.004423078149557114, 0.03543131798505783),
    },
    conv_out: {
        bias: (0.0, 0.0),
        kernel: (0.0, 0.0),
    },
})
 
compare parameter (mean, variance) for controlnet_mid_block
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_0
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_1
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_2
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_3
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_4
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_5
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_6
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_7
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_8
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_9
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
 
compare parameter (mean, variance) for controlnet_down_blocks_10
=== pytorch ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})
=== flax ===
FrozenDict({
    bias: (0.0, 0.0),
    kernel: (0.0, 0.0),
})

@yiyixuxu
Copy link
Collaborator Author

confirm that all slow tests for Flax Stable Diffusion Tests passing on this branch, once I update the slices in main to make the tests pass on main @patrickvonplaten

@patrickvonplaten
Copy link
Contributor

Very nice! Feel free to merge :-)

@yiyixuxu yiyixuxu merged commit df91c44 into main Mar 23, 2023
@kashif kashif deleted the flax-controlnet branch March 24, 2023 10:03
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* add contronet flax

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add contronet flax

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add contronet flax

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants