DiT Pipeline #1806

Merged: 107 commits (merged Jan 17, 2023)

Changes from 5 commits

Commits
351c9a7
added dit model
kashif Dec 21, 2022
a833b3e
import
kashif Dec 22, 2022
9ec5e44
initial pipeline
kashif Dec 22, 2022
5c7d4f3
initial convert script
kashif Dec 22, 2022
037cfc7
initial pipeline
kashif Dec 22, 2022
9259d9f
make style
kashif Dec 22, 2022
bcc7b04
raise valueerror
kashif Dec 22, 2022
a3f7f83
single function
kashif Dec 22, 2022
0e26563
rename classes
kashif Dec 22, 2022
5cbcf53
use DDIMScheduler
kashif Dec 22, 2022
75b809c
Merge branch 'main' into dit
kashif Dec 22, 2022
a0dbe46
timesteps embedder
kashif Dec 22, 2022
f2a074e
samples to cpu
kashif Dec 23, 2022
68dfd3f
fix var names
kashif Dec 23, 2022
c6c476f
fix numpy type
kashif Dec 23, 2022
a719508
use timesteps class for proj
kashif Dec 25, 2022
e61a606
fix typo
kashif Dec 25, 2022
83b9607
fix arg name
kashif Dec 25, 2022
1f127e0
flip_sin_to_cos and better var names
kashif Dec 28, 2022
b8ec4ef
fix C shape cal
kashif Dec 28, 2022
af59837
make style
kashif Dec 28, 2022
f6b1fb1
remove unused imports
kashif Dec 28, 2022
2cc64b8
cleanup
kashif Dec 29, 2022
d9d1bb3
add back patch_size
kashif Dec 29, 2022
c048ba2
initial dit doc
kashif Dec 29, 2022
ab6773b
typo
kashif Dec 29, 2022
96bff9a
Update docs/source/api/pipelines/dit.mdx
kashif Dec 29, 2022
f8a42e3
added copyright license headers
kashif Dec 29, 2022
7467114
Merge branch 'dit' of https://github.com/kashif/diffusers into dit
kashif Dec 29, 2022
25d6809
added example usage and toc
kashif Dec 29, 2022
08d0b93
fix variable names asserts
kashif Dec 29, 2022
989dac8
remove comment
kashif Dec 29, 2022
aa3501f
added docs
kashif Dec 29, 2022
307f072
fix typo
kashif Dec 29, 2022
5518786
upstream changes
kashif Dec 29, 2022
2fcbfed
set proper device for drop_ids
kashif Dec 30, 2022
b75c90e
added initial dit pipeline test
kashif Dec 30, 2022
6856242
Merge branch 'main' into dit
kashif Dec 30, 2022
4530648
update docs
kashif Dec 30, 2022
98d594d
fix imports
kashif Dec 30, 2022
ecf643b
make fix-copies
kashif Dec 30, 2022
635f03c
isort
kashif Dec 30, 2022
abde319
fix imports
kashif Dec 30, 2022
f5ba639
get rid of more magic numbers
kashif Dec 31, 2022
29a4a15
fix code when guidance is off
kashif Dec 31, 2022
988a38c
remove block_kwargs
kashif Jan 2, 2023
95bc036
cleanup script
kashif Jan 2, 2023
9fe5fd7
Merge branch 'main' into dit
kashif Jan 2, 2023
951dcac
removed to_2tuple
kashif Jan 2, 2023
fa08b52
use FeedForward class instead of another MLP
kashif Jan 3, 2023
aea1495
style
kashif Jan 3, 2023
5646549
Merge branch 'main' into dit
kashif Jan 4, 2023
7511634
work on merging DiTBlock with BasicTransformerBlock
williamberman Jan 8, 2023
18b0ab8
added missing final_dropout and args to BasicTransformerBlock
kashif Jan 8, 2023
3ea0809
use norm from block
kashif Jan 9, 2023
ef32b53
fix arg
kashif Jan 9, 2023
997d68f
remove unused arg
kashif Jan 9, 2023
0826d95
fix call to class_embedder
kashif Jan 9, 2023
7817338
use timesteps
kashif Jan 9, 2023
e334853
make style
kashif Jan 9, 2023
a412aa9
attn_output gets multiplied
kashif Jan 9, 2023
7744b1d
removed commented code
kashif Jan 9, 2023
d69bdb0
use Transformer2D
kashif Jan 10, 2023
7672549
use self.is_input_patches
kashif Jan 10, 2023
ded56d9
fix flags
kashif Jan 10, 2023
852750c
fixed conversion to use Transformer2DModel
kashif Jan 10, 2023
edbadee
fixes for pipeline
kashif Jan 10, 2023
e90c2a0
remove dit.py
kashif Jan 11, 2023
bd0668d
Merge remote-tracking branch 'upstream/main' into dit
kashif Jan 11, 2023
f4d034e
fix timesteps device
kashif Jan 11, 2023
5c47f06
use randn_tensor and fix fp16 inf.
kashif Jan 11, 2023
dcdd94f
timesteps_emb already the right dtype
kashif Jan 11, 2023
ef81931
fix dit test class
kashif Jan 11, 2023
6c96e25
fix test and style
kashif Jan 13, 2023
3f646a0
fix norm2 usage in vq-diffusion
kashif Jan 13, 2023
3bd1795
added author names to pipeline and ImageNet labels link
kashif Jan 13, 2023
956dc5b
fix tests
kashif Jan 13, 2023
7fd1e3f
use norm_type as string
kashif Jan 13, 2023
372a0cc
rename dit to transformer
kashif Jan 13, 2023
f76b56a
Merge branch 'main' into dit
kashif Jan 13, 2023
f78f067
fix name
kashif Jan 13, 2023
46c10d3
fix test
kashif Jan 13, 2023
ad21291
set norm_type = "layer" by default
kashif Jan 13, 2023
2c2bf5e
Merge branch 'main' into dit
kashif Jan 16, 2023
5aa3ebb
Merge branch 'main' into dit
kashif Jan 16, 2023
95df97e
Merge branch 'main' into dit
kashif Jan 16, 2023
5a03801
fix tests
kashif Jan 16, 2023
72cfe79
do not skip common tests
kashif Jan 16, 2023
0611450
Update src/diffusers/models/attention.py
kashif Jan 17, 2023
1662798
revert AdaLayerNorm API
kashif Jan 17, 2023
b7bf49c
fix norm_type name
kashif Jan 17, 2023
9a1fddf
make sure all components are in eval mode
kashif Jan 17, 2023
5867e3f
revert norm2 API
kashif Jan 17, 2023
df44ae5
compact
kashif Jan 17, 2023
3b91645
Merge branch 'main' into dit
kashif Jan 17, 2023
6fb598d
finish deprecation
patrickvonplaten Jan 17, 2023
13f5fb6
add slow tests
patrickvonplaten Jan 17, 2023
28afb02
remove @
patrickvonplaten Jan 17, 2023
79ff6d2
refactor some stuff
patrickvonplaten Jan 17, 2023
e84159d
upload
patrickvonplaten Jan 17, 2023
25d4d0b
Update src/diffusers/pipelines/dit/pipeline_dit.py
patrickvonplaten Jan 17, 2023
956e221
finish more
patrickvonplaten Jan 17, 2023
8d159ac
finish docs
patrickvonplaten Jan 17, 2023
2d73a61
improve docs
patrickvonplaten Jan 17, 2023
17ab998
Merge branch 'dit' of https://github.com/kashif/diffusers into dit
patrickvonplaten Jan 17, 2023
66f051e
finish docs
patrickvonplaten Jan 17, 2023
bc652a5
Merge branch 'dit' of https://github.com/kashif/diffusers into dit
patrickvonplaten Jan 17, 2023
103 changes: 103 additions & 0 deletions scripts/convert_dit_to_diffusers.py
```python
import argparse
import os

import torch
from torchvision.datasets.utils import download_url

from diffusers import AutoencoderKL, DDPMScheduler, DiT, DiTPipeline

pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}


def download_model(model_name):
    """
    Downloads a pre-trained DiT model from the web.
    """
    local_path = f"pretrained_models/{model_name}"
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
        download_url(web_path, "pretrained_models")
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def main(args):
    state_dict = download_model(pretrained_models[args.image_size])
    vae = AutoencoderKL.from_pretrained(args.vae_model)

    # Rename the timestep embedder's MLP weights to diffusers' naming scheme.
    state_dict["t_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
    state_dict["t_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
    state_dict["t_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
    state_dict["t_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
    state_dict.pop("t_embedder.mlp.0.weight")
    state_dict.pop("t_embedder.mlp.0.bias")
    state_dict.pop("t_embedder.mlp.2.weight")
    state_dict.pop("t_embedder.mlp.2.bias")

    # Split each block's fused qkv projection into the separate
    # to_q / to_k / to_v projections that diffusers' attention expects.
    for depth in range(28):
        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)

        state_dict[f"blocks.{depth}.attn.to_q.weight"] = q
        state_dict[f"blocks.{depth}.attn.to_q.bias"] = q_bias
        state_dict[f"blocks.{depth}.attn.to_k.weight"] = k
        state_dict[f"blocks.{depth}.attn.to_k.bias"] = k_bias
        state_dict[f"blocks.{depth}.attn.to_v.weight"] = v
        state_dict[f"blocks.{depth}.attn.to_v.bias"] = v_bias

        # The attention output projection is named to_out.0 in diffusers.
        state_dict[f"blocks.{depth}.attn.to_out.0.weight"] = state_dict[f"blocks.{depth}.attn.proj.weight"]
        state_dict[f"blocks.{depth}.attn.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]

        state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
        state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
        state_dict.pop(f"blocks.{depth}.attn.proj.weight")
        state_dict.pop(f"blocks.{depth}.attn.proj.bias")

    # DiT-XL/2 configuration; the VAE latent is image_size // 8 on each side.
    dit = DiT(
        input_size=args.image_size // 8,
        depth=28,
        hidden_size=1152,
        patch_size=2,
        num_heads=16,
    )
    dit.load_state_dict(state_dict)

    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule="linear",
        prediction_type="epsilon",
    )

    pipeline = DiTPipeline(dit=dit, vae=vae, scheduler=scheduler)

    if args.save:
        pipeline.save_pretrained(args.checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--image_size",
        default=256,
        type=int,
        required=False,
        help="Image size of pretrained model, either 256 or 512.",
    )
    parser.add_argument(
        "--vae_model",
        default="stabilityai/sd-vae-ft-ema",
        type=str,
        required=False,
        help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
    )
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
    )

    args = parser.parse_args()
    main(args)
```
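For context, a minimal usage sketch for the converted pipeline follows. It is not part of this diff: the output path ./dit-pipeline, the seed, and the class id are illustrative assumptions, and the call signature (class_labels, num_inference_steps) follows the DiTPipeline API as finally merged, which may differ slightly at this 5-commit snapshot.

```python
# A minimal sketch, not part of this PR's diff. Assumes the conversion
# script above was run as:
#   python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path ./dit-pipeline
# where "./dit-pipeline" is a hypothetical output path.
import torch

from diffusers import DiTPipeline

pipe = DiTPipeline.from_pretrained("./dit-pipeline")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# DiT is class-conditional on ImageNet labels; class id 207 is "golden retriever".
generator = torch.Generator().manual_seed(33)
output = pipe(class_labels=[207], num_inference_steps=50, generator=generator)
output.images[0].save("golden_retriever.png")
```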
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py

```diff
@@ -27,6 +27,7 @@
 from .modeling_utils import ModelMixin
 from .models import (
     AutoencoderKL,
+    DiT,
     PriorTransformer,
     Transformer2DModel,
     UNet1DModel,
@@ -48,6 +49,7 @@
     DanceDiffusionPipeline,
     DDIMPipeline,
     DDPMPipeline,
+    DiTPipeline,
     KarrasVePipeline,
     LDMPipeline,
     LDMSuperResolutionPipeline,
```
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py

```diff
@@ -17,6 +17,7 @@
 if is_torch_available():
     from .attention import Transformer2DModel
+    from .dit import DiT
     from .prior_transformer import PriorTransformer
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
```
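As a quick sketch of what these re-exports enable (assuming this branch is installed from source), both names become importable from the package root at this snapshot; note that a later commit in this PR ("rename dit to transformer") removes DiT in favor of Transformer2DModel.

```python
# Smoke test for the new re-exports, assuming this PR branch of diffusers
# is installed from source. These names exist only at this snapshot of the
# PR; the merged version folds DiT into Transformer2DModel.
from diffusers import DiT, DiTPipeline
from diffusers.models import DiT as DiTModel

assert DiT is DiTModel  # both import paths resolve to the same class
print(DiTPipeline.__name__)  # "DiTPipeline"
```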