From 838ca9ad7ec1755fcb8f86df39b6cccfcb6daa8d Mon Sep 17 00:00:00 2001 From: duongna21 Date: Sat, 5 Nov 2022 22:40:22 +0700 Subject: [PATCH] load text encoder from subfolder --- examples/dreambooth/train_dreambooth_flax.py | 4 +++- examples/text_to_image/train_text_to_image_flax.py | 4 +++- examples/textual_inversion/textual_inversion_flax.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 84493b1d9484..078a66e4acee 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -452,7 +452,9 @@ def collate_fn(examples): weight_dtype = jnp.bfloat16 # Load models and create wrapper for stable diffusion - text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype) + text_encoder = FlaxCLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype + ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype ) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index cacfacef498b..89a8dec7289e 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -379,7 +379,9 @@ def collate_fn(examples): # Load models and create wrapper for stable diffusion tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") - text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype) + text_encoder = FlaxCLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype + ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 84ff97c39a96..be2b7ffb5490 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -391,7 +391,7 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")