diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index b90737892815..20228582e9e2 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -14,6 +14,7 @@ # limitations under the License. """ Conversion script for the LDM checkpoints. """ +import torch import argparse from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt @@ -123,6 +124,7 @@ parser.add_argument( "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." ) + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") args = parser.parse_args() pipe = download_from_original_stable_diffusion_ckpt( @@ -143,6 +145,9 @@ controlnet=args.controlnet, ) + if args.half: + pipe.to(torch_dtype=torch.float16) + if args.controlnet: # only save the controlnet model pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)