Skip to content

Flax controlnet #2727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/api/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module

## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL

## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput

## FlaxControlNetModel
[[autodoc]] FlaxControlNetModel
6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

## FlaxStableDiffusionControlNetPipeline
[[autodoc]] FlaxStableDiffusionControlNetPipeline
- all
- __call__

2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
Expand All @@ -199,6 +200,7 @@
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
FlaxStableDiffusionControlNetPipeline,
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
from .vq_model import VQModel

if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
Loading