Skip to content

Commit 3346ec3

Browse files
prathikrPrathik Raoanton-l
authored
integrate ort (#1110)
* integrate ort * use return_dict=False * revert unet return value change * revert unet return value change * add note to readme * adjust readme * add contact * `make style` Co-authored-by: Prathik Rao <[email protected]> Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 61719bf commit 3346ec3

File tree

3 files changed

+275
-0
lines changed

3 files changed

+275
-0
lines changed

examples/unconditional_image_generation/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,24 @@ dataset.push_to_hub("name_of_your_dataset", private=True)
127127
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
128128

129129
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
130+
131+
#### Use ONNXRuntime to accelerate training
132+
133+
In order to leverage onnxruntime to accelerate training, please use train_unconditional_ort.py
134+
135+
The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime:
136+
137+
```bash
138+
accelerate launch train_unconditional_ort.py \
139+
--dataset_name="huggan/flowers-102-categories" \
140+
--resolution=64 \
141+
--output_dir="ddpm-ema-flowers-64" \
142+
--train_batch_size=16 \
143+
--num_epochs=1 \
144+
--gradient_accumulation_steps=1 \
145+
--learning_rate=1e-4 \
146+
--lr_warmup_steps=500 \
147+
--mixed_precision=fp16
148+
```
149+
150+
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
import argparse
2+
import math
3+
import os
4+
5+
import torch
6+
import torch.nn.functional as F
7+
8+
from accelerate import Accelerator
9+
from accelerate.logging import get_logger
10+
from datasets import load_dataset
11+
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
12+
from diffusers.hub_utils import init_git_repo, push_to_hub
13+
from diffusers.optimization import get_scheduler
14+
from diffusers.training_utils import EMAModel
15+
from onnxruntime.training.ortmodule import ORTModule
16+
from torchvision.transforms import (
17+
CenterCrop,
18+
Compose,
19+
InterpolationMode,
20+
Normalize,
21+
RandomHorizontalFlip,
22+
Resize,
23+
ToTensor,
24+
)
25+
from tqdm.auto import tqdm
26+
27+
28+
logger = get_logger(__name__)
29+
30+
31+
def main(args):
32+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
33+
accelerator = Accelerator(
34+
gradient_accumulation_steps=args.gradient_accumulation_steps,
35+
mixed_precision=args.mixed_precision,
36+
log_with="tensorboard",
37+
logging_dir=logging_dir,
38+
)
39+
40+
model = UNet2DModel(
41+
sample_size=args.resolution,
42+
in_channels=3,
43+
out_channels=3,
44+
layers_per_block=2,
45+
block_out_channels=(128, 128, 256, 256, 512, 512),
46+
down_block_types=(
47+
"DownBlock2D",
48+
"DownBlock2D",
49+
"DownBlock2D",
50+
"DownBlock2D",
51+
"AttnDownBlock2D",
52+
"DownBlock2D",
53+
),
54+
up_block_types=(
55+
"UpBlock2D",
56+
"AttnUpBlock2D",
57+
"UpBlock2D",
58+
"UpBlock2D",
59+
"UpBlock2D",
60+
"UpBlock2D",
61+
),
62+
)
63+
model = ORTModule(model)
64+
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
65+
optimizer = torch.optim.AdamW(
66+
model.parameters(),
67+
lr=args.learning_rate,
68+
betas=(args.adam_beta1, args.adam_beta2),
69+
weight_decay=args.adam_weight_decay,
70+
eps=args.adam_epsilon,
71+
)
72+
73+
augmentations = Compose(
74+
[
75+
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
76+
CenterCrop(args.resolution),
77+
RandomHorizontalFlip(),
78+
ToTensor(),
79+
Normalize([0.5], [0.5]),
80+
]
81+
)
82+
83+
if args.dataset_name is not None:
84+
dataset = load_dataset(
85+
args.dataset_name,
86+
args.dataset_config_name,
87+
cache_dir=args.cache_dir,
88+
use_auth_token=True if args.use_auth_token else None,
89+
split="train",
90+
)
91+
else:
92+
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
93+
94+
def transforms(examples):
95+
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
96+
return {"input": images}
97+
98+
dataset.set_transform(transforms)
99+
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
100+
101+
lr_scheduler = get_scheduler(
102+
args.lr_scheduler,
103+
optimizer=optimizer,
104+
num_warmup_steps=args.lr_warmup_steps,
105+
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
106+
)
107+
108+
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
109+
model, optimizer, train_dataloader, lr_scheduler
110+
)
111+
112+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
113+
114+
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
115+
116+
if args.push_to_hub:
117+
repo = init_git_repo(args, at_init=True)
118+
119+
if accelerator.is_main_process:
120+
run = os.path.split(__file__)[-1].split(".")[0]
121+
accelerator.init_trackers(run)
122+
123+
global_step = 0
124+
for epoch in range(args.num_epochs):
125+
model.train()
126+
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
127+
progress_bar.set_description(f"Epoch {epoch}")
128+
for step, batch in enumerate(train_dataloader):
129+
clean_images = batch["input"]
130+
# Sample noise that we'll add to the images
131+
noise = torch.randn(clean_images.shape).to(clean_images.device)
132+
bsz = clean_images.shape[0]
133+
# Sample a random timestep for each image
134+
timesteps = torch.randint(
135+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
136+
).long()
137+
138+
# Add noise to the clean images according to the noise magnitude at each timestep
139+
# (this is the forward diffusion process)
140+
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
141+
142+
with accelerator.accumulate(model):
143+
# Predict the noise residual
144+
noise_pred = model(noisy_images, timesteps, return_dict=True)[0]
145+
loss = F.mse_loss(noise_pred, noise)
146+
accelerator.backward(loss)
147+
148+
accelerator.clip_grad_norm_(model.parameters(), 1.0)
149+
optimizer.step()
150+
lr_scheduler.step()
151+
if args.use_ema:
152+
ema_model.step(model)
153+
optimizer.zero_grad()
154+
155+
# Checks if the accelerator has performed an optimization step behind the scenes
156+
if accelerator.sync_gradients:
157+
progress_bar.update(1)
158+
global_step += 1
159+
160+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
161+
if args.use_ema:
162+
logs["ema_decay"] = ema_model.decay
163+
progress_bar.set_postfix(**logs)
164+
accelerator.log(logs, step=global_step)
165+
progress_bar.close()
166+
167+
accelerator.wait_for_everyone()
168+
169+
# Generate sample images for visual inspection
170+
if accelerator.is_main_process:
171+
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
172+
pipeline = DDPMPipeline(
173+
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
174+
scheduler=noise_scheduler,
175+
)
176+
177+
generator = torch.manual_seed(0)
178+
# run pipeline in inference (sample random noise and denoise)
179+
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
180+
181+
# denormalize the images and save to tensorboard
182+
images_processed = (images * 255).round().astype("uint8")
183+
accelerator.trackers[0].writer.add_images(
184+
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
185+
)
186+
187+
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
188+
# save the model
189+
if args.push_to_hub:
190+
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
191+
else:
192+
pipeline.save_pretrained(args.output_dir)
193+
accelerator.wait_for_everyone()
194+
195+
accelerator.end_training()
196+
197+
198+
if __name__ == "__main__":
199+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
200+
parser.add_argument("--local_rank", type=int, default=-1)
201+
parser.add_argument("--dataset_name", type=str, default=None)
202+
parser.add_argument("--dataset_config_name", type=str, default=None)
203+
parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
204+
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
205+
parser.add_argument("--overwrite_output_dir", action="store_true")
206+
parser.add_argument("--cache_dir", type=str, default=None)
207+
parser.add_argument("--resolution", type=int, default=64)
208+
parser.add_argument("--train_batch_size", type=int, default=16)
209+
parser.add_argument("--eval_batch_size", type=int, default=16)
210+
parser.add_argument("--num_epochs", type=int, default=100)
211+
parser.add_argument("--save_images_epochs", type=int, default=10)
212+
parser.add_argument("--save_model_epochs", type=int, default=10)
213+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
214+
parser.add_argument("--learning_rate", type=float, default=1e-4)
215+
parser.add_argument("--lr_scheduler", type=str, default="cosine")
216+
parser.add_argument("--lr_warmup_steps", type=int, default=500)
217+
parser.add_argument("--adam_beta1", type=float, default=0.95)
218+
parser.add_argument("--adam_beta2", type=float, default=0.999)
219+
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
220+
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
221+
parser.add_argument("--use_ema", action="store_true", default=True)
222+
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
223+
parser.add_argument("--ema_power", type=float, default=3 / 4)
224+
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
225+
parser.add_argument("--push_to_hub", action="store_true")
226+
parser.add_argument("--use_auth_token", action="store_true")
227+
parser.add_argument("--hub_token", type=str, default=None)
228+
parser.add_argument("--hub_model_id", type=str, default=None)
229+
parser.add_argument("--hub_private_repo", action="store_true")
230+
parser.add_argument("--logging_dir", type=str, default="logs")
231+
parser.add_argument(
232+
"--mixed_precision",
233+
type=str,
234+
default="no",
235+
choices=["no", "fp16", "bf16"],
236+
help=(
237+
"Whether to use mixed precision. Choose"
238+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
239+
"and an Nvidia Ampere GPU."
240+
),
241+
)
242+
243+
args = parser.parse_args()
244+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
245+
if env_local_rank != -1 and env_local_rank != args.local_rank:
246+
args.local_rank = env_local_rank
247+
248+
if args.dataset_name is None and args.train_data_dir is None:
249+
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
250+
251+
main(args)

src/diffusers/pipeline_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@
7878
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
7979
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
8080
},
81+
"onnxruntime.training": {
82+
"ORTModule": ["save_pretrained", "from_pretrained"],
83+
},
8184
}
8285

8386
ALL_IMPORTABLE_CLASSES = {}

0 commit comments

Comments
 (0)