-
Notifications
You must be signed in to change notification settings - Fork 6.1k
Flax controlnet #2727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flax controlnet #2727
Conversation
The documentation is not available anymore as the PR was closed or merged. |
PR looks great! Happy to merge if Flax Stable Diffusion Tests are still all passing |
weight initialization compare The default initializer in pytorch import jax
from diffusers import FlaxControlNetModel, ControlNetModel
from diffusers.models.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from flax.core.frozen_dict import freeze
config_f = FlaxControlNetModel.load_config("lllyasviel/sd-controlnet-canny")
model_f = FlaxControlNetModel.from_config(config_f)
params_f = model_f.init_weights(rng=jax.random.PRNGKey(0))
config = ControlNetModel.load_config("lllyasviel/sd-controlnet-canny")
model = ControlNetModel.from_config(config)
params = convert_pytorch_state_dict_to_flax(model.state_dict(), model_f)
params = freeze(params)
for key in ['controlnet_cond_embedding','controlnet_mid_block'] + [f'controlnet_down_blocks_{i}' for i in range(11)]:
print(f'compare parameter (mean, variance) for {key}')
print("=== pytorch ===")
print(jax.tree_map(lambda x: (x.mean().item(), x.var().item()), params[key]))
print("=== flax ===")
print(jax.tree_map(lambda x: (x.mean().item(), x.var().item()), params_f[key]))
print(' ')
|
confirm that all slow tests for Flax Stable Diffusion Tests passing on this branch, once I update the slices in main to make the tests pass on main @patrickvonplaten |
Very nice! Feel free to merge :-) |
* add contronet flax --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
* add contronet flax --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
* add contronet flax --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
to-do:
Example
CPU equivalency test