Skip to content

Commit bf5ca03

Browse files
[Flax] Add Vae for Stable Diffusion (#555)
* [Flax] Add Vae * correct * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * Finish Co-authored-by: Suraj Patil <[email protected]>
1 parent b17d49f commit bf5ca03

File tree

3 files changed

+610
-0
lines changed

3 files changed

+610
-0
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
if is_flax_available():
6666
from .modeling_flax_utils import FlaxModelMixin
6767
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
68+
from .models.vae_flax import FlaxAutoencoderKL
6869
from .schedulers import (
6970
FlaxDDIMScheduler,
7071
FlaxDDPMScheduler,

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def from_pretrained(
294294
local_files_only=local_files_only,
295295
use_auth_token=use_auth_token,
296296
revision=revision,
297+
subfolder=subfolder,
297298
# model args
298299
dtype=dtype,
299300
**kwargs,

0 commit comments

Comments
 (0)