
Commit 2bc6de6

[SD] Add support for a compiled version of the discrete Euler scheduler (huggingface#657)
* Add Shark version of euler scheduler
* Add Shark version of euler scheduler to web ui
1 parent ffef168 commit 2bc6de6

File tree: 10 files changed, +345 −28 lines

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 30 additions & 11 deletions
@@ -17,6 +17,9 @@
 from stable_args import args
 from utils import get_shark_model, set_iree_runtime_flags
 from opt_params import get_unet, get_vae, get_clip
+from schedulers import (
+    SharkEulerDiscreteScheduler,
+)
 import time
 import sys
 from shark.iree_utils.compile_utils import dump_isas
@@ -78,6 +81,7 @@ def end_profiling(device):
         "CompVis/stable-diffusion-v1-4",
         subfolder="scheduler",
     )
+cpu_scheduling = True
 if args.version == "v2.1":
     tokenizer = CLIPTokenizer.from_pretrained(
         "stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
@@ -93,10 +97,19 @@ def end_profiling(device):
         "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
     )

-    scheduler = EulerDiscreteScheduler.from_pretrained(
-        "stabilityai/stable-diffusion-2-1-base",
-        subfolder="scheduler",
-    )
+    if args.use_compiled_scheduler:
+        scheduler = SharkEulerDiscreteScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-2-1-base",
+            subfolder="scheduler",
+        )
+        scheduler.compile()
+        cpu_scheduling = False
+    else:
+        scheduler = EulerDiscreteScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-2-1-base",
+            subfolder="scheduler",
+        )
+
 start = time.time()

 text_input = tokenizer(
@@ -144,36 +157,42 @@ def end_profiling(device):
     print(f"i = {i} t = {t}", end="")
     timestep = torch.tensor([t]).to(dtype).detach().numpy()
     latent_model_input = scheduler.scale_model_input(latents, t)
-    latents_numpy = latent_model_input.detach().numpy()
+    if cpu_scheduling:
+        latent_model_input = latent_model_input.detach().numpy()

     profile_device = start_profiling(file_path="unet.rdc")

     noise_pred = unet.forward(
         (
-            latents_numpy,
+            latent_model_input,
             timestep,
             text_embeddings_numpy,
             guidance_scale,
-        )
+        ),
+        send_to_host=False,
     )

     end_profiling(profile_device)

-    noise_pred = torch.from_numpy(noise_pred)
+    if cpu_scheduling:
+        noise_pred = torch.from_numpy(noise_pred.to_host())
+        latents = scheduler.step(noise_pred, t, latents).prev_sample
+    else:
+        latents = scheduler.step(noise_pred, t, latents)
     step_time = time.time() - step_start
     avg_ms += step_time
     step_ms = int((step_time) * 1000)
     print(f" ({step_ms}ms)")

-    latents = scheduler.step(noise_pred, t, latents).prev_sample
-
 avg_ms = 1000 * avg_ms / args.steps
 print(f"Average step time: {avg_ms}ms/it")

 # scale and decode the image latents with vae
 latents = 1 / 0.18215 * latents
 # latents = latents.
-latents_numpy = latents.detach().numpy()
+latents_numpy = latents
+if cpu_scheduling:
+    latents_numpy = latents.detach().numpy()
 profile_device = start_profiling(file_path="vae.rdc")
 vae_start = time.time()
 image = vae.forward((latents_numpy,))
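
Read together, these hunks give the denoising loop two modes keyed off cpu_scheduling. A condensed sketch of the resulting control flow (not the literal file; variable setup is assumed to match the surrounding script):

latent_model_input = scheduler.scale_model_input(latents, t)
if cpu_scheduling:
    # stock diffusers scheduler: torch tensor -> numpy for the compiled UNet
    latent_model_input = latent_model_input.detach().numpy()

noise_pred = unet.forward(
    (latent_model_input, timestep, text_embeddings_numpy, guidance_scale),
    send_to_host=False,  # keep the UNet output as an IREE device array
)

if cpu_scheduling:
    # one explicit device-to-host copy, then the usual diffusers step
    noise_pred = torch.from_numpy(noise_pred.to_host())
    latents = scheduler.step(noise_pred, t, latents).prev_sample
else:
    # compiled scheduler: step runs on device and returns the array directly
    latents = scheduler.step(noise_pred, t, latents)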
shark/examples/shark_inference/stable_diffusion/schedulers.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+import sys
+import numpy as np
+from typing import List, Optional, Tuple, Union
+from diffusers import (
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerDiscreteScheduler,
+)
+from diffusers.configuration_utils import register_to_config
+from utils import compile_through_fx, get_shark_model
+from stable_args import args
+import torch
+
+SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
+
+model_input = {
+    "euler": {
+        "latent": torch.randn(1, 4, 64, 64),
+        "output": torch.randn(1, 4, 64, 64),
+        "sigma": torch.tensor(1).to(torch.float32),
+        "dt": torch.tensor(1).to(torch.float32),
+    },
+}
+
+
+class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
+    ):
+        super().__init__(
+            num_train_timesteps,
+            beta_start,
+            beta_end,
+            beta_schedule,
+            trained_betas,
+            prediction_type,
+        )
+
+    def compile(self):
+        example_latent = model_input["euler"]["latent"]
+        example_output = model_input["euler"]["output"]
+        if args.precision == "fp16":
+            example_latent = example_latent.half()
+            example_output = example_output.half()
+        example_sigma = model_input["euler"]["sigma"]
+        example_dt = model_input["euler"]["dt"]
+
+        class ScalingModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, latent, sigma):
+                return latent / ((sigma**2 + 1) ** 0.5)
+
+        class SchedulerStepModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, noise_pred, sigma, latent, dt):
+                pred_original_sample = latent - sigma * noise_pred
+                derivative = (latent - pred_original_sample) / sigma
+                return latent + derivative * dt
+
+        iree_flags = []
+        if len(args.iree_vulkan_target_triple) > 0:
+            iree_flags.append(
+                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+            )
+        # Disable bindings fusion to work with moltenVK.
+        if sys.platform == "darwin":
+            iree_flags.append("-iree-stream-fuse-binding=false")
+
+        if args.import_mlir:
+            scaling_model = ScalingModel()
+            self.scaling_model = compile_through_fx(
+                scaling_model,
+                (example_latent, example_sigma),
+                model_name="euler_scale_model_input_" + args.precision,
+                extra_args=iree_flags,
+            )
+
+            step_model = SchedulerStepModel()
+            self.step_model = compile_through_fx(
+                step_model,
+                (example_output, example_sigma, example_latent, example_dt),
+                model_name="euler_step_" + args.precision,
+                extra_args=iree_flags,
+            )
+        else:
+            self.scaling_model = get_shark_model(
+                SCHEDULER_BUCKET,
+                "euler_scale_model_input_" + args.precision,
+                iree_flags,
+            )
+            self.step_model = get_shark_model(
+                SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
+            )
+
+    def scale_model_input(self, sample, timestep):
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        return self.scaling_model.forward(
+            (
+                sample,
+                sigma,
+            ),
+            send_to_host=False,
+        )
+
+    def step(self, noise_pred, timestep, latent):
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        dt = self.sigmas[step_index + 1] - sigma
+        return self.step_model.forward(
+            (
+                noise_pred,
+                sigma,
+                latent,
+                dt,
+            ),
+            send_to_host=False,
+        )
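
A quick sanity check on SchedulerStepModel: for epsilon prediction, (latent - pred_original_sample) / sigma collapses back to noise_pred, so the compiled step is algebraically latent + noise_pred * dt. Keeping the expanded form mirrors the diffusers Euler reference step line for line, which makes the compiled module easier to audit. A standalone verification (pure torch, no Shark required; the tensor values are arbitrary):

import torch

# SchedulerStepModel's algebra reduces to latent + noise_pred * dt.
latent = torch.randn(1, 4, 64, 64)
noise_pred = torch.randn(1, 4, 64, 64)
sigma, dt = torch.tensor(7.0), torch.tensor(-0.5)

pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma  # == noise_pred
stepped = latent + derivative * dt

assert torch.allclose(stepped, latent + noise_pred * dt)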

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 7 additions & 0 deletions
@@ -132,6 +132,13 @@
 ### Misc. Debug and Optimization flags
 ##############################################################################

+p.add_argument(
+    "--use_compiled_scheduler",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="use the default scheduler precompiled into the model if available",
+)
+
 p.add_argument(
     "--local_tank_cache",
     default="",

shark/iree_utils/compile_utils.py

Lines changed: 16 additions & 6 deletions
@@ -348,21 +348,31 @@ def export_module_to_mlir_file(module, frontend, directory: str):
     return filename


-def get_results(compiled_vm, input, config, frontend="torch"):
+def get_results(
+    compiled_vm, input, config, frontend="torch", send_to_host=True
+):
     """Runs a .vmfb file given inputs and config and returns output."""
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
     result = compiled_vm(*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
-        for val in result:
-            result_tensors.append(np.copy(np.asarray(val, val.dtype)))
+        if send_to_host:
+            for val in result:
+                result_tensors.append(np.asarray(val, val.dtype))
+        else:
+            for val in result:
+                result_tensors.append(val)
         return result_tensors
     elif isinstance(result, dict):
         data = list(result.items())
-        res = np.array(data, dtype=object)
-        return np.copy(res)
+        if send_to_host:
+            res = np.array(data, dtype=object)
+            return np.copy(res)
+        return data
     else:
-        return result.to_host()
+        if send_to_host:
+            return result.to_host()
+        return result


 def get_iree_runtime_config(device):
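
Hypothetical call sites showing the two modes of get_results (vm, inputs, and config stand in for a compiled module, its input tuple, and a runtime config; none of these names come from the diff):

host_out = get_results(vm, inputs, config)  # default: copied to host as numpy
dev_out = get_results(vm, inputs, config, send_to_host=False)  # IREE device array
host_copy = dev_out.to_host()  # explicit copy only when a host value is needed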

shark/shark_inference.py

Lines changed: 2 additions & 2 deletions
@@ -138,8 +138,8 @@ def compile(self, extra_args=[]):
         os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")

     # inputs are considered to be tuple of np.array.
-    def forward(self, inputs: tuple):
-        return self.shark_runner.run(inputs)
+    def forward(self, inputs: tuple, send_to_host=True):
+        return self.shark_runner.run(inputs, send_to_host)

     # Captures the static input information from the mlir_module.
     # TODO(pashu123): Generate the input information for dynamic shapes.
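
SharkInference.forward keeps send_to_host=True as its default, so existing callers still receive host numpy unless they opt out. A hedged example (shark_module is a compiled SharkInference and x a numpy input; both names are illustrative):

out_np = shark_module.forward((x,))                       # host numpy (default)
out_dev = shark_module.forward((x,), send_to_host=False)  # stays on device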

shark/shark_runner.py

Lines changed: 2 additions & 1 deletion
@@ -91,10 +91,11 @@ def __init__(
             extra_args=self.extra_args,
         )

-    def run(self, inputs: tuple):
+    def run(self, inputs: tuple, send_to_host=False):
         return get_results(
             self.iree_compilation_module,
             inputs,
             self.iree_config,
             self.mlir_dialect,
+            send_to_host,
         )
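
Note the asymmetry in this diff: SharkRunner.run defaults send_to_host to False, while SharkInference.forward and get_results default to True, so code calling the runner directly now sees device arrays unless it opts back in. An illustrative sketch (runner is a SharkRunner and x a numpy input; both names are stand-ins):

dev = runner.run((x,))                      # new default: IREE device array
host = runner.run((x,), send_to_host=True)  # previous behaviour, host numpy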

web/index.py

Lines changed: 2 additions & 1 deletion
@@ -114,13 +114,14 @@ def resource_path(relative_path):
         with gr.Row():
             scheduler_key = gr.Dropdown(
                 label="Scheduler",
-                value="EulerDiscrete",
+                value="SharkEulerDiscrete",
                 choices=[
                     "DDIM",
                     "PNDM",
                     "LMSDiscrete",
                     "DPMSolverMultistep",
                     "EulerDiscrete",
+                    "SharkEulerDiscrete",
                 ],
             )
         with gr.Group():

web/models/stable_diffusion/cache_objects.py

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,9 @@
 from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
 from models.stable_diffusion.utils import set_iree_runtime_flags
 from models.stable_diffusion.stable_args import args
+from models.stable_diffusion.schedulers import (
+    SharkEulerDiscreteScheduler,
+)
 from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag


@@ -39,6 +42,11 @@
     model_config[args.version],
     subfolder="scheduler",
 )
+schedulers["SharkEulerDiscrete"] = SharkEulerDiscreteScheduler.from_pretrained(
+    model_config[args.version],
+    subfolder="scheduler",
+)
+schedulers["SharkEulerDiscrete"].compile()

 # use tuned unet model in case of rdna3 cards.
 if "rdna3" in get_vulkan_triple_flag():

web/models/stable_diffusion/main.py

Lines changed: 16 additions & 7 deletions
@@ -56,6 +56,7 @@ def stable_diff_inf(
         cache_obj["tokenizer"],
     )
     scheduler = schedulers[scheduler_key]
+    cpu_scheduling = not scheduler_key.startswith("Shark")

     start = time.time()
     text_input = tokenizer(
@@ -104,27 +105,35 @@ def stable_diff_inf(

         step_start = time.time()
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
-        latents_model_input = scheduler.scale_model_input(latents, t)
-        latents_numpy = latents_model_input.detach().numpy()
+        latent_model_input = scheduler.scale_model_input(latents, t)
+        if cpu_scheduling:
+            latent_model_input = latent_model_input.detach().numpy()

         noise_pred = unet.forward(
             (
-                latents_numpy,
+                latent_model_input,
                 timestep,
                 text_embeddings_numpy,
                 args.guidance_scale,
-            )
+            ),
+            send_to_host=False,
         )
-        noise_pred = torch.from_numpy(noise_pred)
+
+        if cpu_scheduling:
+            noise_pred = torch.from_numpy(noise_pred.to_host())
+            latents = scheduler.step(noise_pred, t, latents).prev_sample
+        else:
+            latents = scheduler.step(noise_pred, t, latents)
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
         print(f" \nIteration = {i}, Time = {step_ms}ms")
-        latents = scheduler.step(noise_pred, t, latents)["prev_sample"]

     # scale and decode the image latents with vae
     latents = 1 / 0.18215 * latents
-    latents_numpy = latents.detach().numpy()
+    latents_numpy = latents
+    if cpu_scheduling:
+        latents_numpy = latents.detach().numpy()
     vae_start = time.time()
     image = vae.forward((latents_numpy,))
     vae_end = time.time()
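
The web path derives the scheduling mode purely from the dropdown key, so any future compiled scheduler only needs a schedulers entry whose key starts with "Shark". A tiny standalone check of that convention:

for key in ("EulerDiscrete", "SharkEulerDiscrete"):
    cpu_scheduling = not key.startswith("Shark")
    print(f"{key}: cpu_scheduling={cpu_scheduling}")
# EulerDiscrete: cpu_scheduling=True
# SharkEulerDiscrete: cpu_scheduling=False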
