Skip to content

Commit df91c44

Browse files
yiyixuxuyiyixuxu
andauthored
Flax controlnet (#2727)
* add contronet flax --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent aa0531f commit df91c44

File tree

13 files changed

+1125
-2
lines changed

13 files changed

+1125
-2
lines changed

docs/source/en/api/models.mdx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
9999

100100
## FlaxAutoencoderKL
101101
[[autodoc]] FlaxAutoencoderKL
102+
103+
## FlaxControlNetOutput
104+
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
105+
106+
## FlaxControlNetModel
107+
[[autodoc]] FlaxControlNetModel

docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
272272
- disable_vae_slicing
273273
- enable_xformers_memory_efficient_attention
274274
- disable_xformers_memory_efficient_attention
275+
276+
## FlaxStableDiffusionControlNetPipeline
277+
[[autodoc]] FlaxStableDiffusionControlNetPipeline
278+
- all
279+
- __call__
280+

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@
188188
except OptionalDependencyNotAvailable:
189189
from .utils.dummy_flax_objects import * # noqa F403
190190
else:
191+
from .models.controlnet_flax import FlaxControlNetModel
191192
from .models.modeling_flax_utils import FlaxModelMixin
192193
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
193194
from .models.vae_flax import FlaxAutoencoderKL
@@ -211,6 +212,7 @@
211212
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
212213
else:
213214
from .pipelines import (
215+
FlaxStableDiffusionControlNetPipeline,
214216
FlaxStableDiffusionImg2ImgPipeline,
215217
FlaxStableDiffusionInpaintPipeline,
216218
FlaxStableDiffusionPipeline,

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@
3030
from .vq_model import VQModel
3131

3232
if is_flax_available():
33+
from .controlnet_flax import FlaxControlNetModel
3334
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
3435
from .vae_flax import FlaxAutoencoderKL

0 commit comments

Comments
 (0)