Skip to content

Commit 7cfc0fa

Browse files
[APPS-SD] Fix a few bugs and bring it up to speed with SD CLI (huggingface#908)
1 parent a908121 commit 7cfc0fa

File tree

11 files changed

+396
-108
lines changed

11 files changed

+396
-108
lines changed

.gitignore

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,5 @@ tank/dict_configs.py
170170
cache_models/
171171
onnx_models/
172172

173-
#web logging
174-
web/logs/
175-
web/stored_results/stable_diffusion/
173+
# Generated images
174+
generated_imgs/

apps/stable_diffusion/scripts/txt2img.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class Config:
4141
for vmfb in vmfbs:
4242
if os.path.exists(vmfb):
4343
os.remove(vmfb)
44+
# Temporary workaround: delete the inference yaml files to incorporate diffusers' pipeline.
45+
# TODO: Remove this once we have better weight-update logic.
46+
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
47+
for yaml in inference_yaml:
48+
if os.path.exists(yaml):
49+
os.remove(yaml)
4450
home = os.path.expanduser("~")
4551
if os.name == "nt": # Windows
4652
appdata = os.getenv("LOCALAPPDATA")

apps/stable_diffusion/src/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@
66
get_unet,
77
get_clip,
88
get_tokenizer,
9+
get_params,
10+
get_variant_version,
911
)

apps/stable_diffusion/src/models/model_wrappers.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
from transformers import CLIPTextModel
33
from collections import defaultdict
44
import torch
5-
import sys
65
import traceback
76
import re
7+
import os, sys, functools, operator
88
from apps.stable_diffusion.src.utils import (
99
compile_through_fx,
1010
get_opt_flags,
1111
base_models,
1212
args,
13+
get_vmfb_path_name,
1314
)
1415

1516

@@ -68,6 +69,7 @@ def __init__(
6869
height: int = 512,
6970
batch_size: int = 1,
7071
use_base_vae: bool = False,
72+
use_tuned: bool = False,
7173
):
7274
self.check_params(max_len, width, height)
7375
self.max_len = max_len
@@ -88,13 +90,15 @@ def __init__(
8890
+ "_"
8991
+ precision
9092
)
93+
self.use_tuned = use_tuned
9194
# We need a better naming convention for the .vmfbs because despite
9295
# using the custom model variant the .vmfb names remain the same and
9396
# it'll always pick up the compiled .vmfb instead of compiling the
9497
# custom model.
9598
# So, currently, we add `self.model_id` in the `self.model_name` of
9699
# .vmfb file.
97100
# TODO: Have a better way of naming the vmfbs using self.model_name.
101+
import re
98102

99103
model_name = re.sub(r"\W+", "_", self.model_id)
100104
if model_name[0] == "_":
@@ -137,6 +141,7 @@ def forward(self, input):
137141
vae,
138142
inputs,
139143
is_f16=is_f16,
144+
use_tuned=self.use_tuned,
140145
model_name=vae_name + self.model_name,
141146
extra_args=get_opt_flags("vae", precision=self.precision),
142147
)
@@ -177,6 +182,7 @@ def forward(
177182
model_name="unet" + self.model_name,
178183
is_f16=is_f16,
179184
f16_input_mask=input_mask,
185+
use_tuned=self.use_tuned,
180186
extra_args=get_opt_flags("unet", precision=self.precision),
181187
)
182188
return shark_unet
@@ -194,7 +200,6 @@ def forward(self, input):
194200
return self.text_encoder(input)[0]
195201

196202
clip_model = CLIPText()
197-
198203
shark_clip = compile_through_fx(
199204
clip_model,
200205
tuple(self.inputs["clip"]),
@@ -204,6 +209,11 @@ def forward(self, input):
204209
return shark_clip
205210

206211
def __call__(self):
212+
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
213+
vmfb_path = [
214+
get_vmfb_path_name(model + self.model_name)[0]
215+
for model in model_name
216+
]
207217
for model_id in base_models:
208218
self.inputs = get_input_info(
209219
base_models[model_id],
@@ -213,12 +223,22 @@ def __call__(self):
213223
self.batch_size,
214224
)
215225
try:
216-
compiled_clip = self.get_clip()
217226
compiled_unet = self.get_unet()
218227
compiled_vae = self.get_vae()
228+
compiled_clip = self.get_clip()
219229
except Exception as e:
220230
if args.enable_stack_trace:
221231
traceback.print_exc()
232+
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
233+
all_vmfb_present = functools.reduce(
234+
operator.__and__, vmfb_present
235+
)
236+
# Delete the existing vmfbs only if some, but not all, of the models were compiled.
237+
if not all_vmfb_present:
238+
for i in range(len(vmfb_path)):
239+
if vmfb_present[i]:
240+
os.remove(vmfb_path[i])
241+
print("Deleted: ", vmfb_path[i])
222242
print("Retrying with a different base model configuration")
223243
continue
224244
# This is done just because in main.py we are basing the choice of tokenizer and scheduler

apps/stable_diffusion/src/models/opt_params.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
}
1515

1616

17+
def get_variant_version(hf_model_id):
18+
return hf_model_variant_map[hf_model_id]
19+
20+
1721
def get_params(bucket_key, model_key, model, is_tuned, precision):
1822
iree_flags = []
1923
if len(args.iree_vulkan_target_triple) > 0:
@@ -60,7 +64,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
6064

6165

6266
def get_unet():
63-
variant, version = hf_model_variant_map[args.hf_model_id]
67+
variant, version = get_variant_version(args.hf_model_id)
6468
# Tuned model is present only for `fp16` precision.
6569
is_tuned = "tuned" if args.use_tuned else "untuned"
6670
if "vulkan" not in args.device and args.use_tuned:
@@ -77,7 +81,7 @@ def get_unet():
7781

7882

7983
def get_vae():
80-
variant, version = hf_model_variant_map[args.hf_model_id]
84+
variant, version = get_variant_version(args.hf_model_id)
8185
# Tuned model is present only for `fp16` precision.
8286
is_tuned = "tuned" if args.use_tuned else "untuned"
8387
is_base = "/base" if args.use_base_vae else ""
@@ -95,7 +99,7 @@ def get_vae():
9599

96100

97101
def get_clip():
98-
variant, version = hf_model_variant_map[args.hf_model_id]
102+
variant, version = get_variant_version(args.hf_model_id)
99103
bucket_key = f"{variant}/untuned"
100104
model_key = (
101105
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"

apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,12 @@ def from_pretrained(
185185
width: int,
186186
use_base_vae: bool,
187187
):
188-
init_kwargs = None
189188
if import_mlir:
190-
if ckpt_loc:
191-
preprocessCKPT()
189+
if ckpt_loc != "":
190+
assert ckpt_loc.lower().endswith(
191+
(".ckpt", ".safetensors")
192+
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
193+
ckpt_loc = preprocessCKPT()
192194
mlir_import = SharkifyStableDiffusionModel(
193195
model_id,
194196
ckpt_loc,

apps/stable_diffusion/src/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
opt_flags,
1010
resource_path,
1111
)
12+
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
1213
from apps.stable_diffusion.src.utils.stable_args import args
1314
from apps.stable_diffusion.src.utils.utils import (
15+
get_vmfb_path_name,
1416
get_shark_model,
1517
compile_through_fx,
1618
set_iree_runtime_flags,
Lines changed: 85 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,101 @@
1-
{
2-
"unet": {
3-
"tuned": {
4-
"fp16": {
5-
"default_compilation_flags": []
6-
},
7-
"fp32": {
8-
"default_compilation_flags": []
9-
}
1+
{
2+
"unet": {
3+
"tuned": {
4+
"fp16": {
5+
"default_compilation_flags": []
106
},
11-
"untuned": {
12-
"fp16": {
13-
"default_compilation_flags": [
14-
"--iree-flow-enable-padding-linalg-ops",
15-
"--iree-flow-linalg-ops-padding-size=32"
16-
],
17-
"specified_compilation_flags": {
18-
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
19-
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
20-
}
21-
},
22-
"fp32": {
23-
"default_compilation_flags": [
24-
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
25-
"--iree-flow-enable-padding-linalg-ops",
26-
"--iree-flow-linalg-ops-padding-size=16"
27-
]
28-
}
7+
"fp32": {
8+
"default_compilation_flags": []
299
}
3010
},
31-
"vae": {
32-
"tuned": {
33-
"fp16": {
34-
"default_compilation_flags": [
35-
"--iree-flow-enable-padding-linalg-ops",
36-
"--iree-flow-linalg-ops-padding-size=32",
37-
"--iree-flow-enable-conv-img2col-transform"
38-
]
39-
},
40-
"fp32": {
41-
"default_compilation_flags": [
42-
"--iree-flow-enable-padding-linalg-ops",
43-
"--iree-flow-linalg-ops-padding-size=32",
44-
"--iree-flow-enable-conv-img2col-transform"
45-
]
11+
"untuned": {
12+
"fp16": {
13+
"default_compilation_flags": [
14+
"--iree-flow-enable-padding-linalg-ops",
15+
"--iree-flow-linalg-ops-padding-size=32"
16+
],
17+
"specified_compilation_flags": {
18+
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
19+
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
4620
}
4721
},
48-
"untuned": {
49-
"fp16": {
50-
"default_compilation_flags": [
22+
"fp32": {
23+
"default_compilation_flags": [
24+
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
25+
"--iree-flow-enable-padding-linalg-ops",
26+
"--iree-flow-linalg-ops-padding-size=16"
27+
]
28+
}
29+
}
30+
},
31+
"vae": {
32+
"tuned": {
33+
"fp16": {
34+
"default_compilation_flags": [],
35+
"specified_compilation_flags": {
36+
"cuda": [],
37+
"default_device": ["--iree-flow-enable-padding-linalg-ops",
38+
"--iree-flow-linalg-ops-padding-size=32",
39+
"--iree-flow-enable-conv-img2col-transform"]
40+
}
41+
},
42+
"fp32": {
43+
"default_compilation_flags": [],
44+
"specified_compilation_flags": {
45+
"cuda": [],
46+
"default_device": [
5147
"--iree-flow-enable-padding-linalg-ops",
5248
"--iree-flow-linalg-ops-padding-size=32",
5349
"--iree-flow-enable-conv-img2col-transform"
5450
]
55-
},
56-
"fp32": {
57-
"default_compilation_flags": [
58-
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
59-
"--iree-flow-enable-padding-linalg-ops",
60-
"--iree-flow-linalg-ops-padding-size=16"
61-
]
6251
}
6352
}
6453
},
65-
"clip": {
66-
"tuned": {
67-
"fp16": {
68-
"default_compilation_flags": [
69-
"--iree-flow-linalg-ops-padding-size=16",
70-
"--iree-flow-enable-padding-linalg-ops"
71-
]
72-
},
73-
"fp32": {
74-
"default_compilation_flags": [
75-
"--iree-flow-linalg-ops-padding-size=16",
76-
"--iree-flow-enable-padding-linalg-ops"
77-
]
78-
}
54+
"untuned": {
55+
"fp16": {
56+
"default_compilation_flags": [
57+
"--iree-flow-enable-padding-linalg-ops",
58+
"--iree-flow-linalg-ops-padding-size=32",
59+
"--iree-flow-enable-conv-img2col-transform"
60+
]
7961
},
80-
"untuned": {
81-
"fp16": {
82-
"default_compilation_flags": [
83-
"--iree-flow-linalg-ops-padding-size=16",
84-
"--iree-flow-enable-padding-linalg-ops"
85-
]
86-
},
87-
"fp32": {
88-
"default_compilation_flags": [
89-
"--iree-flow-linalg-ops-padding-size=16",
90-
"--iree-flow-enable-padding-linalg-ops"
91-
]
92-
}
62+
"fp32": {
63+
"default_compilation_flags": [
64+
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
65+
"--iree-flow-enable-padding-linalg-ops",
66+
"--iree-flow-linalg-ops-padding-size=16"
67+
]
68+
}
69+
}
70+
},
71+
"clip": {
72+
"tuned": {
73+
"fp16": {
74+
"default_compilation_flags": [
75+
"--iree-flow-linalg-ops-padding-size=16",
76+
"--iree-flow-enable-padding-linalg-ops"
77+
]
78+
},
79+
"fp32": {
80+
"default_compilation_flags": [
81+
"--iree-flow-linalg-ops-padding-size=16",
82+
"--iree-flow-enable-padding-linalg-ops"
83+
]
84+
}
85+
},
86+
"untuned": {
87+
"fp16": {
88+
"default_compilation_flags": [
89+
"--iree-flow-linalg-ops-padding-size=16",
90+
"--iree-flow-enable-padding-linalg-ops"
91+
]
92+
},
93+
"fp32": {
94+
"default_compilation_flags": [
95+
"--iree-flow-linalg-ops-padding-size=16",
96+
"--iree-flow-enable-padding-linalg-ops"
97+
]
9398
}
9499
}
95100
}
101+
}

0 commit comments

Comments
 (0)