227 commits
ae151ff
wrapped model in ORTModule
Sep 22, 2022
0624ebd
bug fix
Sep 22, 2022
7a68077
bug fixes
Oct 19, 2022
62c0493
formatting
Oct 19, 2022
d237f0f
bug fix
Oct 20, 2022
cc09d21
remove commented blocks
Oct 20, 2022
3216e7e
bug fix
Oct 20, 2022
9ca7106
bug fix
Oct 20, 2022
eafdeec
formatting
Oct 20, 2022
6bd9120
add onnxruntime.training
Oct 20, 2022
051aa83
Merge branch 'main' into prathikrao/ort-integration
prathikr Oct 20, 2022
abe1637
docs: `.md` readability fixups (#619)
ryanrussell Sep 23, 2022
f6d8eb9
Flax documentation (#589)
younesbelkada Sep 23, 2022
78ae5a2
fix docs: change sample to images (#613)
AbdullahAlfaraj Sep 23, 2022
59b8506
refactor: pipelines readability improvements (#622)
ryanrussell Sep 23, 2022
f80643e
Allow passing session_options for ORT backend (#620)
cloudhan Sep 23, 2022
dfcc372
Fix breaking error: "ort is not defined" (#626)
pcuenca Sep 23, 2022
835279e
docs: `src/diffusers` readability improvements (#629)
ryanrussell Sep 24, 2022
9bd5f62
Fix formula for noise levels in Karras scheduler and tests (#627)
sgrigory Sep 24, 2022
c4619c0
[CI] Fix onnxruntime installation order (#633)
anton-l Sep 24, 2022
f610439
Warning for too long prompts in DiffusionPipelines (Resolve #447) (#472)
shirayu Sep 27, 2022
dea0c7b
Fix docs link to train_unconditional.py (#642)
AbdullahAlfaraj Sep 27, 2022
c9a729d
Remove deprecated `torch_device` kwarg (#623)
pcuenca Sep 27, 2022
f8c99ef
refactor: `custom_init_isort` readability fixups (#631)
ryanrussell Sep 27, 2022
b43ed93
Remove inappropriate docstrings in LMS docstrings. (#634)
pcuenca Sep 27, 2022
fff1416
Flax pipeline pndm (#583)
pcuenca Sep 27, 2022
ee622c0
Fix `SpatialTransformer` (#578)
ydshieh Sep 27, 2022
7ea5cf0
Add training example for DreamBooth. (#554)
Victarry Sep 27, 2022
8187550
rebase off main
Oct 26, 2022
2160283
[examples/dreambooth] don't pass tensor_format to scheduler. (#649)
patil-suraj Sep 27, 2022
dd626cb
[dreambooth] update install section (#650)
patil-suraj Sep 27, 2022
24902f8
[DDIM, DDPM] fix add_noise (#648)
patil-suraj Sep 27, 2022
57a861a
[Pytorch] add dep. warning for pytorch schedulers (#651)
kashif Sep 27, 2022
6dacef0
[CLIPGuidedStableDiffusion] remove set_format from pipeline (#653)
patil-suraj Sep 27, 2022
8b80e61
Fix onnx tensor format (#654)
anton-l Sep 27, 2022
ae850af
Fix `main`: stable diffusion pipelines cannot be loaded (#655)
pcuenca Sep 27, 2022
7659455
Fix the LMS pytorch regression (#664)
anton-l Sep 28, 2022
16b165d
Added script to save during textual inversion training. Issue 524 (#645)
isamu-isozaki Sep 28, 2022
50a41eb
[CLIPGuidedStableDiffusion] take the correct text embeddings (#667)
patil-suraj Sep 28, 2022
67d78e2
Update index.mdx (#670)
tmabraham Sep 29, 2022
a9e135d
[examples] update transfomers version (#665)
patil-suraj Sep 29, 2022
970c912
[gradient checkpointing] lower tolerance for test (#652)
patil-suraj Sep 29, 2022
6f8f7b2
Flax `from_pretrained`: clean up `mismatched_keys`. (#630)
pcuenca Sep 29, 2022
b898718
`trained_betas` ignored in some schedulers (#635)
vishnu-anirudh Sep 29, 2022
3517a4e
Renamed x -> hidden_states in resnet.py (#676)
daspartho Sep 29, 2022
8c3b76d
Optimize Stable Diffusion (#371)
NouamaneTazi Sep 30, 2022
422fef5
Allow resolutions that are not multiples of 64 (#505)
jachiam Sep 30, 2022
23514cf
refactor: update ldm-bert `config.json` url closes #675 (#680)
ryanrussell Sep 30, 2022
89e1149
[docs] fix table in fp16.mdx (#683)
NouamaneTazi Sep 30, 2022
1c55b0d
Update README.md
patrickvonplaten Sep 30, 2022
bf758b6
Update README.md
patrickvonplaten Sep 30, 2022
a8b1edb
Fix slow tests (#689)
NouamaneTazi Sep 30, 2022
819b573
Fix BibText citation (#693)
osanseviero Oct 1, 2022
cf83856
Add callback parameters for Stable Diffusion pipelines (#521)
jamestiotio Oct 2, 2022
47c3773
[dreambooth] fix applying clip_grad_norm_ (#686)
patil-suraj Oct 3, 2022
a75419a
Forgot to add the OG!
patrickvonplaten Oct 3, 2022
ec4f665
Flax: add shape argument to `set_timesteps` (#690)
pcuenca Oct 3, 2022
b079483
Fix type annotations on StableDiffusionPipeline.__call__ (#682)
tasercake Oct 3, 2022
65a1175
Fix import with Flax but without PyTorch (#688)
pcuenca Oct 3, 2022
9030eb1
[Support PyTorch 1.8] Remove inference mode (#707)
patrickvonplaten Oct 3, 2022
a8dc58f
[CI] Speed up slow tests (#708)
anton-l Oct 3, 2022
a3e21e9
[Utils] Add deprecate function and move testing_utils under utils (#659)
patrickvonplaten Oct 3, 2022
36f7ef5
Checkpoint conversion script from Diffusers => Stable Diffusion (Comp…
jachiam Oct 4, 2022
d026acd
[Docs] fix docstring for issue #709 (#710)
kashif Oct 4, 2022
a3a13fd
Update schedulers README.md (#694)
tmabraham Oct 4, 2022
b5c71f1
add accelerate to load models with smaller memory footprint (#361)
piEsposito Oct 4, 2022
0016774
Fix typos (#718)
shirayu Oct 4, 2022
3d05177
Add an argument "negative_prompt" (#549)
shirayu Oct 4, 2022
c75e492
Fix import if PyTorch is not installed (#715)
pcuenca Oct 4, 2022
b9eafed
Remove comments no longer appropriate (#716)
pcuenca Oct 4, 2022
e5152d4
[train_unconditional] fix applying clip_grad_norm_ (#721)
patil-suraj Oct 4, 2022
dae2931
renamed x to meaningful variable in resnet.py (#677)
i-am-epic Oct 4, 2022
439cedb
[Tests] Add accelerate to testing (#729)
patrickvonplaten Oct 5, 2022
5a3bbb6
[dreambooth] Using already created `Path` in dataset (#681)
DrInfiniteExplorer Oct 5, 2022
97c1c1c
Include CLIPTextModel parameters in conversion (#695)
kanewallmann Oct 5, 2022
f71023c
Avoid negative strides for tensors (#717)
shirayu Oct 5, 2022
032196b
[Pytorch] pytorch only timesteps (#724)
kashif Oct 5, 2022
4d342f3
[Scheduler design] The pragmatic approach (#719)
anton-l Oct 5, 2022
54f58c5
Removing `autocast` for `35-25% speedup`. (`autocast` considered harm…
Narsil Oct 5, 2022
23978d1
No more use_auth_token=True (#733)
patrickvonplaten Oct 5, 2022
4c084b8
remove use_auth_token from remaining places (#737)
patil-suraj Oct 5, 2022
6a28ab3
Replace messages that have empty backquotes (#738)
pcuenca Oct 5, 2022
56613f9
[Docs] Advertise fp16 instead of autocast (#740)
patrickvonplaten Oct 5, 2022
9170ad6
make style
patrickvonplaten Oct 5, 2022
76b1e5d
remove use_auth_token from for TI test (#747)
patil-suraj Oct 6, 2022
72a52f3
allow multiple generations per prompt (#741)
patil-suraj Oct 6, 2022
e910006
Add back-compatibility to LMS timesteps (#750)
anton-l Oct 6, 2022
cf085e1
update the clip guided PR according to the new API (#751)
patil-suraj Oct 6, 2022
54dc5b5
Raise an error when moving an fp16 pipeline to CPU (#749)
anton-l Oct 6, 2022
7f4abe9
Better steps deprecation for LMS (#753)
anton-l Oct 6, 2022
5c6b46c
Actually fix the grad ckpt test (#734)
patil-suraj Oct 6, 2022
c8828d7
Custome Pipelines (#744)
patrickvonplaten Oct 6, 2022
db5f4a3
make CI happy
patrickvonplaten Oct 6, 2022
b37e313
Python 3.7 doesn't like keys() + keys()
patrickvonplaten Oct 6, 2022
5ddc621
[v0.4.0] Temporarily remove Flax modules from the public API (#755)
anton-l Oct 6, 2022
d7abf97
Release: v0.4.0
anton-l Oct 6, 2022
9418ce4
Update clip_guided_stable_diffusion.py
patil-suraj Oct 6, 2022
609dce7
Bump to v0.4.1.dev0
anton-l Oct 6, 2022
ff75b77
Revert "[v0.4.0] Temporarily remove Flax modules from the public API …
anton-l Oct 6, 2022
a931b6b
Bump to v0.5.0.dev0
anton-l Oct 6, 2022
6de1401
Update clip_guided_stable_diffusion.py
patil-suraj Oct 6, 2022
6b88f0e
Created using Colaboratory
patil-suraj Oct 6, 2022
1457eb8
[Tests] Lower required memory for clip guided and fix super edge-case…
patrickvonplaten Oct 6, 2022
2d7d98e
Revert "Bump to v0.5.0.dev0"
anton-l Oct 6, 2022
e66d4aa
Change fp16 error to warning (#764)
apolinario Oct 7, 2022
0caaadc
Release: v0.4.1
patrickvonplaten Oct 7, 2022
b4a439b
Bump to v0.5.0dev0
patrickvonplaten Oct 7, 2022
3ad166c
remove bogus folder
patrickvonplaten Oct 7, 2022
81d39a5
remove bogus folder no.2
patrickvonplaten Oct 7, 2022
9f5ecb6
Fix push_to_hub for dreambooth and textual_inversion (#748)
YaYaB Oct 7, 2022
87ffde4
Fix ONNX conversion script opset argument type (#739)
justinchuby Oct 7, 2022
fec34cf
Add final latent slice checks to SD pipeline intermediate state tests…
jamestiotio Oct 7, 2022
31aa9df
fix(DDIM scheduler): use correct dtype for noise (#742)
keturn Oct 7, 2022
8275695
[schedulers] hanlde dtype in add_noise (#767)
patil-suraj Oct 7, 2022
346a163
[img2img, inpainting] fix fp16 inference (#769)
patil-suraj Oct 7, 2022
20d7ddc
[Tests] Fix tests (#774)
patrickvonplaten Oct 7, 2022
afb4294
debug an exception (#638)
LowinLi Oct 10, 2022
455689a
Clean up resnet.py file (#780)
Oct 10, 2022
7e39045
add sigmoid betas (#777)
Oct 10, 2022
da92a17
[Low CPU memory] + device map (#772)
patrickvonplaten Oct 10, 2022
a9517a6
Fix gradient checkpointing test (#797)
patrickvonplaten Oct 10, 2022
ea02fcb
fix typo docstring in unet2d (#798)
Oct 10, 2022
2d0758f
DreamBooth DeepSpeed support for under 8 GB VRAM training (#735)
Ttl Oct 10, 2022
57aacd4
support bf16 for stable diffusion (#792)
patil-suraj Oct 11, 2022
2d54b72
stable diffusion fine-tuning (#356)
patil-suraj Oct 11, 2022
97a44d7
Flax: Trickle down `norm_num_groups` (#789)
akash5474 Oct 11, 2022
182c8c0
Eventually preserve this typo? :) (#804)
spezialspezial Oct 11, 2022
fd07a1e
Fix indentation in the code example (#802)
osanseviero Oct 11, 2022
67b8ad1
`mps`: Alternative implementation for `repeat_interleave` (#766)
pcuenca Oct 11, 2022
ac50fe1
Update img2img.mdx
patrickvonplaten Oct 11, 2022
16fed0e
Add diffusers version and pipeline class to the Hub UA
anton-l Oct 12, 2022
833b581
[Img2Img] Fix batch size mismatch prompts vs. init images (#793)
patrickvonplaten Oct 12, 2022
cba1451
Minor package fixes (#809)
anton-l Oct 12, 2022
430954a
[Dummy imports] Better error message (#795)
patrickvonplaten Oct 12, 2022
6b6f92f
Revert an accidental commit
anton-l Oct 12, 2022
1b6749e
add or fix license formatting in models directory (#808)
Oct 12, 2022
9f8b13f
[train_text2image] Fix EMA and make it compatible with deepspeed. (#813)
patil-suraj Oct 12, 2022
4b894e3
Fix fine-tuning compatibility with deepspeed (#816)
pink-red Oct 12, 2022
ae70c6f
Add diffusers version and pipeline class to the Hub UA (#814)
anton-l Oct 12, 2022
8dc3311
[Flax] Add test (#824)
patrickvonplaten Oct 13, 2022
c56ff93
update flax scheduler API (#822)
patil-suraj Oct 13, 2022
a588c90
Fix dreambooth loss type with prior_preservation and fp16 (#826)
anton-l Oct 13, 2022
e16fb86
Fix type mismatch error, add tests for negative prompts (#823)
anton-l Oct 13, 2022
4186e6e
Give more customizable options for safety checker (#815)
patrickvonplaten Oct 13, 2022
d270300
Flax safety checker (#825)
pcuenca Oct 13, 2022
a410459
Align PT and Flax API - allow loading checkpoint from PyTorch configs…
patrickvonplaten Oct 13, 2022
83ced1e
[Flax] Complete tests (#828)
patrickvonplaten Oct 13, 2022
a9a22e3
Release: 5.0.0 (#830)
anton-l Oct 13, 2022
256a5cd
[FlaxStableDiffusionPipeline] fix bug when nsfw is detected (#832)
patil-suraj Oct 13, 2022
27cb665
Release 0 5 1 (#833)
patrickvonplaten Oct 13, 2022
6a472c2
[Community] One step unet (#840)
patrickvonplaten Oct 14, 2022
e0cbba0
Remove unneeded use_auth_token (#839)
osanseviero Oct 14, 2022
61a4fdb
Bump to 0.6.0.dev0 (#831)
anton-l Oct 14, 2022
8e4792f
Remove the last of ["sample"] (#842)
anton-l Oct 14, 2022
0671472
Fix Flax pipeline: width and height are ignored #838 (#848)
camenduru Oct 14, 2022
7a311c6
[DeviceMap] Make sure stable diffusion can be loaded from older trans…
patrickvonplaten Oct 16, 2022
8c978c1
Update README.md
patrickvonplaten Oct 17, 2022
1158684
Update README.md
patrickvonplaten Oct 17, 2022
9bbd287
Add Stable Diffusion Interpolation Example (#862)
nateraw Oct 17, 2022
4f0540f
Update README.md
patrickvonplaten Oct 17, 2022
7296493
All in one Stable Diffusion Pipeline (#821)
patrickvonplaten Oct 17, 2022
0e86077
Fix small community pipeline import bug and finish README (#869)
patrickvonplaten Oct 17, 2022
6d89d09
Update README.md
patrickvonplaten Oct 17, 2022
6dff127
Fix training push_to_hub (unconditional image generation): models wer…
pcuenca Oct 17, 2022
e3b14fc
Fix table in community README.md (#879)
nateraw Oct 17, 2022
d62304b
Add generic inference example to community pipeline readme (#874)
apolinario Oct 17, 2022
a60ed3c
Rename frame filename in interpolation community example (#881)
nateraw Oct 17, 2022
30ea4c2
Add Apple M1 tests (#796)
anton-l Oct 17, 2022
95d6abb
Fix autoencoder test (#886)
pcuenca Oct 17, 2022
03d4040
Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)
anton-l Oct 18, 2022
ec7a08d
Fix DDIM on Windows not using int64 for timesteps (#819)
hafriedlander Oct 18, 2022
1eddce8
[dreambooth] allow fine-tuning text encoder (#883)
patil-suraj Oct 18, 2022
cd179a6
Stable Diffusion image-to-image and inpaint using onnx. (#552)
zledas Oct 18, 2022
aee9614
Improve ONNX img2img numpy handling, temporarily fix the tests (#899)
anton-l Oct 19, 2022
2e37ad9
make fix copies
patrickvonplaten Oct 19, 2022
a008713
[Stable Diffusion Inpainting] Deprecate inpainting pipeline in favor …
patrickvonplaten Oct 19, 2022
dfc350a
[Communit Pipeline] Make sure "mega" uses correct inpaint pipeline (#…
patrickvonplaten Oct 19, 2022
23ac047
Stable diffusion inpainting. (#904)
patil-suraj Oct 19, 2022
fc96721
finish tests (#909)
patrickvonplaten Oct 19, 2022
8560585
ONNX supervised inpainting (#906)
anton-l Oct 19, 2022
117a02a
Initial docs update for new in-painting pipeline (#910)
pcuenca Oct 19, 2022
6e36920
Release: 0.6.0
anton-l Oct 19, 2022
23579b9
[Community Pipelines] Long Prompt Weighting Stable Diffusion Pipeline…
SkyTNT Oct 19, 2022
bde7046
[Stable Diffusion] Add components function (#889)
patrickvonplaten Oct 20, 2022
50a50f4
[PNDM Scheduler] Make sure list cannot grow forever (#882)
patrickvonplaten Oct 20, 2022
f997c4e
[DiffusionPipeline.from_pretrained] add warning when passing unused k…
patrickvonplaten Oct 20, 2022
0f89280
DOC Dreambooth Add --sample_batch_size=1 to the 8 GB dreambooth examp…
leszekhanusz Oct 20, 2022
c387e48
[Examples] add speech to image pipeline example (#897)
MikailINTech Oct 20, 2022
759e3a3
[dreambooth] dont use safety check when generating prior images (#922)
patil-suraj Oct 20, 2022
82c1f6d
Dreambooth class image generation: using unique names to avoid overwr…
leszekhanusz Oct 20, 2022
20b69cf
fix test_components (#928)
patil-suraj Oct 20, 2022
8409fba
Fix Compatibility with Nvidia NGC Containers (#919)
tasercake Oct 20, 2022
30a933e
[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffus…
SkyTNT Oct 20, 2022
cfdea72
Bump the version to 0.7.0.dev0 (#912)
anton-l Oct 20, 2022
21d570d
Introduce the copy mechanism (#924)
anton-l Oct 20, 2022
5749075
Merge branch 'prathikrao/ort-integration' of https://github.com/prath…
Oct 26, 2022
b75cc61
updated unet tests
Oct 26, 2022
b07538e
[gradient checkpointing] lower tolerance for test (#652)
patil-suraj Sep 29, 2022
aa5d5c7
[Tests] Add accelerate to testing (#729)
patrickvonplaten Oct 5, 2022
9059061
Created using Colaboratory
patil-suraj Oct 6, 2022
c3043ab
remove bogus folder no.2
patrickvonplaten Oct 7, 2022
5e3f6f2
Add generic inference example to community pipeline readme (#874)
apolinario Oct 17, 2022
6eeeaaf
Add Apple M1 tests (#796)
anton-l Oct 17, 2022
9941b7a
Fix autoencoder test (#886)
pcuenca Oct 17, 2022
205cf55
updated unet tests
Oct 26, 2022
cf02e9b
Merge branch 'prathikrao/ort-integration' of https://github.com/prath…
Oct 26, 2022
11a6b99
remove random code add
Oct 26, 2022
05534d1
remove random code add
Oct 26, 2022
01392ee
[Tests] Add accelerate to testing (#729)
patrickvonplaten Oct 5, 2022
05d2f08
Created using Colaboratory
patil-suraj Oct 6, 2022
7a6f2df
remove bogus folder no.2
patrickvonplaten Oct 7, 2022
910e821
Add generic inference example to community pipeline readme (#874)
apolinario Oct 17, 2022
7af99fc
Add Apple M1 tests (#796)
anton-l Oct 17, 2022
36c671e
Fix autoencoder test (#886)
pcuenca Oct 17, 2022
0e087d5
rebase
Oct 26, 2022
a3d8758
[Tests] Add accelerate to testing (#729)
patrickvonplaten Oct 5, 2022
3782e87
Created using Colaboratory
patil-suraj Oct 6, 2022
82781f4
remove bogus folder no.2
patrickvonplaten Oct 7, 2022
88c39c5
Add generic inference example to community pipeline readme (#874)
apolinario Oct 17, 2022
89f664f
Add Apple M1 tests (#796)
anton-l Oct 17, 2022
d50f071
Fix autoencoder test (#886)
pcuenca Oct 17, 2022
c50019f
rebase
Oct 26, 2022
e8e2917
remove random code add
Oct 26, 2022
17ab2d3
Merge branch 'prathikrao/ort-integration' of https://github.com/prath…
Oct 26, 2022
d6df970
seperate script for ort
Oct 31, 2022
67243b4
formatting
Oct 31, 2022
eafdd59
removed random eval statement
Oct 31, 2022
Files changed
@@ -138,7 +138,7 @@ def transforms(examples):
 
             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps).sample
+                noise_pred = model(noisy_images, timesteps)
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
 
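This call-site edit in the existing training script tracks the `src/diffusers/models/unet_2d.py` change further down: `UNet2DModel.forward` now returns the predicted-noise tensor directly rather than a `UNet2DOutput`, so the `.sample` accessor goes away.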
251 changes: 251 additions & 0 deletions examples/unconditional_image_generation/train_unconditional_ort.py
@@ -0,0 +1,251 @@
import argparse
import math
import os

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm
from onnxruntime.training.ortmodule import ORTModule


logger = get_logger(__name__)


def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    model = ORTModule(model)
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )

    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps)
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
                    ema_model.step(model)
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,
                )

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")
                accelerator.trackers[0].writer.add_images(
                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                )

            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model
                pipeline.save_pretrained(args.output_dir)
                if args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
        accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_config_name", type=str, default=None)
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10)
    parser.add_argument("--save_model_epochs", type=int, default=10)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--use_ema", action="store_true", default=True)
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--ort", action="store_true")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    main(args)
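The ORT hook-in above is deliberately small: import `ORTModule`, wrap the `UNet2DModel`, and let `accelerator.prepare` handle the rest. A minimal sketch of the same pattern on a toy module, assuming the `onnxruntime-training` package that provides `onnxruntime.training.ortmodule` is installed (the toy network and shapes are illustrative, not from the PR):

import torch
from onnxruntime.training.ortmodule import ORTModule

# ORTModule exports the wrapped module to ONNX on the first forward pass and
# then runs both forward and backward through ONNX Runtime's training engine.
net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
net = ORTModule(net)

x = torch.randn(4, 8)
loss = net(x).sum()  # forward executes inside ONNX Runtime
loss.backward()      # gradients land on the wrapped module's parameters
assert all(p.grad is not None for p in net.parameters())

Two things worth noting about the script itself: the wrap happens before `accelerator.prepare(...)`, and the `--ort` flag it defines is parsed but never consulted, so the `ORTModule` wrap is unconditional.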
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d.py
@@ -258,4 +258,4 @@ def forward(
         if not return_dict:
             return (sample,)
 
-        return UNet2DOutput(sample=sample)
+        return sample
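Dropping the `UNet2DOutput` wrapper here is presumably what makes the UNet traceable through `ORTModule`, which at this point handled tensors (and simple containers of tensors) as module outputs but not custom dataclass outputs. The `return_dict=False` branch above already returned a plain tuple and is unchanged.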
3 changes: 3 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -72,6 +72,9 @@
         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
         "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
     },
+    "onnxruntime.training": {
+        "ORTModule": ["save_pretrained", "from_pretrained"],
+    }
 }
 
 ALL_IMPORTABLE_CLASSES = {}
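`LOADABLE_CLASSES` is the registry `DiffusionPipeline` walks to decide how each of its sub-modules should be saved and re-loaded; registering `ORTModule` under `onnxruntime.training` lets a pipeline holding a wrapped UNet round-trip through `save_pretrained`/`from_pretrained` (`ORTModule` delegates attribute access to the module it wraps, which is presumably what keeps the diffusers save/load methods reachable). A rough sketch of how the registry is consulted on save, simplified for illustration rather than a verbatim copy of `pipeline_utils.py`:

import importlib

def save_sub_model(sub_model, save_dir):
    # The first registered base class that the sub-model is an instance of
    # determines which save method gets called on it.
    for library_name, class_candidates in LOADABLE_CLASSES.items():
        try:
            library = importlib.import_module(library_name)
        except ImportError:
            continue  # e.g. onnxruntime.training when onnxruntime-training isn't installed
        for class_name, (save_method, _load_method) in class_candidates.items():
            base = getattr(library, class_name, None)
            if base is not None and isinstance(sub_model, base):
                getattr(sub_model, save_method)(save_dir)
                return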
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -77,7 +77,7 @@ def __call__(
 
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
-            model_output = self.unet(image, t).sample
+            model_output = self.unet(image, t)
 
             # 2. compute previous image: x_t -> t_t-1
             image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
12 changes: 6 additions & 6 deletions tests/test_models_unet.py
@@ -131,7 +131,7 @@ def test_from_pretrained_hub(self):
         self.assertEqual(len(loading_info["missing_keys"]), 0)
 
         model.to(torch_device)
-        image = model(**self.dummy_input).sample
+        image = model(**self.dummy_input)
 
         assert image is not None, "Make sure output is not None"
 
@@ -141,7 +141,7 @@ def test_from_pretrained_accelerate(self):
             "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
         )
         model.to(torch_device)
-        image = model(**self.dummy_input).sample
+        image = model(**self.dummy_input)
 
         assert image is not None, "Make sure output is not None"
 
@@ -219,7 +219,7 @@ def test_output_pretrained(self):
         time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
 
         with torch.no_grad():
-            output = model(noise, time_step).sample
+            output = model(noise, time_step)
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
@@ -294,7 +294,7 @@ def test_gradient_checkpointing(self):
 
         assert model_2.is_gradient_checkpointing and model_2.training
 
-        out_2 = model_2(**inputs_dict).sample
+        out_2 = model_2(**inputs_dict)
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model_2.zero_grad()
@@ -412,7 +412,7 @@ def test_output_pretrained_ve_mid(self):
         time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
 
         with torch.no_grad():
-            output = model(noise, time_step).sample
+            output = model(noise, time_step)
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
@@ -437,7 +437,7 @@ def test_output_pretrained_ve_large(self):
         time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
 
         with torch.no_grad():
-            output = model(noise, time_step).sample
+            output = model(noise, time_step)
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
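The six test updates are the same mechanical call-site change: every `model(...).sample` becomes `model(...)`, matching the new return type of `UNet2DModel.forward`.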