diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 84493b1d9484..078a66e4acee 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -452,7 +452,9 @@ def collate_fn(examples):
         weight_dtype = jnp.bfloat16
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
+    text_encoder = FlaxCLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
+    )
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
     )
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index cacfacef498b..89a8dec7289e 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -379,7 +379,9 @@ def collate_fn(examples):
 
     # Load models and create wrapper for stable diffusion
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
-    text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
+    text_encoder = FlaxCLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
+    )
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
     )
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 84ff97c39a96..be2b7ffb5490 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -391,7 +391,7 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
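
The common thread of the three hunks: the Flax example scripts now load the CLIP text encoder from the `text_encoder` subfolder of `args.pretrained_model_name_or_path`, alongside the tokenizer, VAE, and UNet from that same checkpoint, instead of hard-coding `openai/clip-vit-large-patch14`, so the encoder always matches the checkpoint being fine-tuned. A minimal sketch of the resulting loading pattern outside the scripts, assuming `runwayml/stable-diffusion-v1-5` as an example repo id (the id and the prompt are illustrative, not taken from the diff):

```python
# Sketch only: mirror the updated scripts by loading tokenizer and text encoder
# from the same Stable Diffusion checkpoint via their subfolders.
import jax.numpy as jnp
from transformers import CLIPTokenizer, FlaxCLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"  # assumed example checkpoint

tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = FlaxCLIPTextModel.from_pretrained(
    model_id, subfolder="text_encoder", dtype=jnp.bfloat16
)

# Encode a prompt with the encoder that actually belongs to this checkpoint,
# using the same call pattern the training scripts use.
input_ids = tokenizer(
    "a photo of a dog",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="np",
).input_ids
hidden_states = text_encoder(input_ids, params=text_encoder.params, train=False)[0]
print(hidden_states.shape)  # (1, 77, 768) for an SD v1-style text encoder
```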