diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index f6affe8a1400..c645d30be7f7 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -59,6 +59,11 @@ def main(args): "UpBlock2D", ), ) + + if args.ort: + from onnxruntime.training import ORTModule + model = ORTModule(model) + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") optimizer = torch.optim.AdamW( model.parameters(), @@ -139,7 +144,10 @@ def transforms(examples): with accelerator.accumulate(model): # Predict the noise residual - noise_pred = model(noisy_images, timesteps).sample + if args.ort: + noise_pred = model(noisy_images, timesteps, return_dict=False)[0] + else: + noise_pred = model(noisy_images, timesteps).sample loss = F.mse_loss(noise_pred, noise) accelerator.backward(loss) @@ -237,6 +245,7 @@ def transforms(examples): "and an Nvidia Ampere GPU." ), ) + parser.add_argument("--ort", action="store_true") args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index be3429e26ac5..c4acaedb2f9e 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -53,6 +53,9 @@ "PreTrainedModel": ["save_pretrained", "from_pretrained"], "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + } } ALL_IMPORTABLE_CLASSES = {}