Commit 1939376

Author: Gaurav Shukla
[WEB] Cache model parameters (huggingface#452)
This commit caches some of the model parameters to reduce the response time of SHARK web.

Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 25931d4 commit 1939376

File tree

6 files changed: +75 -95 lines changed


setup_venv.ps1

Lines changed: 1 addition & 1 deletion
@@ -35,6 +35,6 @@ pip install --pre torch-mlir torch torchvision --extra-index-url https://downloa
 pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
 Write-Host "Building SHARK..."
 pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
-pip install diffusers transformers scipy
+pip install diffusers transformers scipy gradio
 Write-Host "Build and installation completed successfully"
 Write-Host "Source your venv with ./shark.venv/Scripts/activate"

web/index.py

Lines changed: 20 additions & 15 deletions
@@ -16,12 +16,20 @@ def debug_event(debug):
     with gr.Row():
         with gr.Group():
             with gr.Column(scale=1):
-                img = Image.open("./Nod_logo.png")
-                gr.Image(value=img, show_label=False, interactive=False).style(
-                    height=80, width=150
-                )
+                nod_logo = Image.open("./logos/Nod_logo.png")
+                gr.Image(
+                    value=nod_logo, show_label=False, interactive=False
+                ).style(height=80, width=150)
             with gr.Column(scale=1):
-                gr.Label(value="Shark Models Demo.")
+                logo2 = Image.open("./logos/other_logo.png")
+                gr.Image(
+                    value=logo2,
+                    show_label=False,
+                    interactive=False,
+                    visible=False,
+                ).style(height=80, width=150)
+            with gr.Column(scale=1):
+                gr.Label(value="Ultra fast Stable Diffusion")

     with gr.Tabs():
         # with gr.TabItem("ResNet50"):
@@ -136,9 +144,7 @@ def debug_event(debug):
             ) = (
                 device
             ) = (
-                load_vmfb
-            ) = (
-                save_vmfb
+                cache
             ) = (
                 iree_vulkan_target_triple
             ) = (
@@ -160,6 +166,7 @@ def debug_event(debug):
                 prompt = gr.Textbox(
                     label="Prompt",
                     value="a photograph of an astronaut riding a horse",
+                    lines=5,
                 )
                 ex = gr.Examples(
                     examples=examples,
@@ -219,19 +226,18 @@ def debug_event(debug):
                     value="42", max_lines=1, label="Seed"
                 )
                 with gr.Row():
-                    load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
-                    save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
+                    cache = gr.Checkbox(label="Cache", value=True)
                     debug = gr.Checkbox(label="DEBUG", value=False)
                     live_preview = gr.Checkbox(
-                        label="live preview", value=False
+                        label="Live Preview", value=False
                     )
                 iree_vulkan_target_triple = gr.Textbox(
                     value="",
                     max_lines=1,
                     label="IREE VULKAN TARGET TRIPLE",
                     visible=False,
                 )
-                stable_diffusion = gr.Button("Generate image from prompt")
+                stable_diffusion = gr.Button("Generate Image")
             with gr.Column(scale=1, min_width=600):
                 generated_img = gr.Image(type="pil", shape=(100, 100))
                 std_output = gr.Textbox(
@@ -261,13 +267,12 @@ def debug_event(debug):
                 seed,
                 precision,
                 device,
-                load_vmfb,
-                save_vmfb,
+                cache,
                 iree_vulkan_target_triple,
                 live_preview,
             ],
             outputs=[generated_img, std_output],
         )

 shark_web.queue()
-shark_web.launch(share=True, server_port=8080, enable_queue=True)
+shark_web.launch(server_port=8080, enable_queue=True)
web/Nod_logo.png → web/logos/Nod_logo.png

File renamed without changes.

web/logos/other_logo.png

Binary file added (32.9 KB).

web/models/stable_diffusion/main.py

Lines changed: 47 additions & 61 deletions
@@ -22,7 +22,6 @@

 IREE_EXTRA_ARGS = []
 args = None
-DEBUG = False


 class Arguments:
@@ -39,8 +38,7 @@ def __init__(
         seed: int,
         precision: str,
         device: str,
-        load_vmfb: bool,
-        save_vmfb: bool,
+        cache: bool,
         iree_vulkan_target_triple: str,
         live_preview: bool,
         import_mlir: bool = False,
@@ -57,8 +55,7 @@ def __init__(
         self.seed = seed
         self.precision = precision
         self.device = device
-        self.load_vmfb = load_vmfb
-        self.save_vmfb = save_vmfb
+        self.cache = cache
         self.iree_vulkan_target_triple = iree_vulkan_target_triple
         self.live_preview = live_preview
         self.import_mlir = import_mlir
@@ -101,6 +98,37 @@ def get_models():
     return None, None


+schedulers = dict()
+# set scheduler value
+schedulers["PNDM"] = PNDMScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    num_train_timesteps=1000,
+)
+schedulers["LMS"] = LMSDiscreteScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    num_train_timesteps=1000,
+)
+schedulers["DDIM"] = DDIMScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    clip_sample=False,
+    set_alpha_to_one=False,
+)
+
+cache_obj = dict()
+cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
+    "openai/clip-vit-large-patch14"
+)
+cache_obj["text_encoder"] = CLIPTextModel.from_pretrained(
+    "openai/clip-vit-large-patch14"
+)
+
+
 def stable_diff_inf(
     prompt: str,
     scheduler: str,
@@ -113,21 +141,17 @@ def stable_diff_inf(
     seed: str,
     precision: str,
     device: str,
-    load_vmfb: bool,
-    save_vmfb: bool,
+    cache: bool,
     iree_vulkan_target_triple: str,
     live_preview: bool,
 ):

     global IREE_EXTRA_ARGS
     global args
-    global DEBUG
+    global schedulers
+    global cache_obj

-    output_loc = f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{precision}_{device}.jpg"
-    DEBUG = False
-    log_write = open(r"logs/stable_diffusion_log.txt", "w")
-    if log_write:
-        DEBUG = True
+    output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg"

     # set seed value
     if seed == "":
@@ -138,34 +162,7 @@ def stable_diff_inf(
     except ValueError:
         seed = hash(seed)

-    # set scheduler value
-    if scheduler == "PNDM":
-        scheduler = PNDMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            num_train_timesteps=1000,
-        )
-    elif scheduler == "LMS":
-        scheduler = LMSDiscreteScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            num_train_timesteps=1000,
-        )
-    elif scheduler == "DDIM":
-        scheduler = DDIMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
-    else:
-        raise Exception(
-            f"Does not support scheduler with name {args.scheduler}."
-        )
-
+    scheduler = schedulers[scheduler]
     args = Arguments(
         prompt,
         scheduler,
@@ -178,8 +175,7 @@ def stable_diff_inf(
         seed,
         precision,
         device,
-        load_vmfb,
-        save_vmfb,
+        cache,
         iree_vulkan_target_triple,
         live_preview,
     )
@@ -194,11 +190,8 @@ def stable_diff_inf(
     ) # Seed generator to create the inital latent noise

     vae, unet = get_models()
-    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    text_encoder = CLIPTextModel.from_pretrained(
-        "openai/clip-vit-large-patch14"
-    )
-
+    tokenizer = cache_obj["tokenizer"]
+    text_encoder = cache_obj["text_encoder"]
     text_input = tokenizer(
         [args.prompt] * batch_size,
         padding="max_length",
@@ -233,10 +226,10 @@ def stable_diff_inf(

     avg_ms = 0
     out_img = None
+    text_output = ""
     for i, t in tqdm(enumerate(scheduler.timesteps)):

-        if DEBUG:
-            log_write.write(f"\ni = {i} t = {t} ")
+        text_output = text_output + f"\ni = {i} t = {t} "
         step_start = time.time()
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
         latents_numpy = latents.detach().numpy()
@@ -249,8 +242,7 @@ def stable_diff_inf(
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
-        if DEBUG:
-            log_write.write(f"time={step_ms}ms")
+        text_output = text_output + f"time={step_ms}ms"
         latents = scheduler.step(noise_pred, i, latents)["prev_sample"]

         if live_preview:
@@ -263,7 +255,7 @@ def stable_diff_inf(
             images = (image * 255).round().astype("uint8")
             pil_images = [Image.fromarray(image) for image in images]
             out_img = pil_images[0]
-            yield out_img, ""
+            yield out_img, text_output

     # scale and decode the image latents with vae
     if not live_preview:
@@ -277,14 +269,8 @@ def stable_diff_inf(
         out_img = pil_images[0]

     avg_ms = 1000 * avg_ms / args.steps
-    if DEBUG:
-        log_write.write(f"\nAverage step time: {avg_ms}ms/it")
+    text_output = text_output + f"\nAverage step time: {avg_ms}ms/it"

     # save the output image with the prompt name.
     out_img.save(os.path.join(output_loc))
-    log_write.close()
-
-    std_output = ""
-    with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
-        std_output = log_read.read()
-    yield out_img, std_output
+    yield out_img, text_output

web/models/stable_diffusion/utils.py

Lines changed: 7 additions & 18 deletions
@@ -7,27 +7,16 @@


 def _compile_module(args, shark_module, model_name, extra_args=[]):
-    if args.load_vmfb or args.save_vmfb:
-        extended_name = "{}_{}".format(model_name, args.device)
+    extended_name = "{}_{}".format(model_name, args.device)
+    if args.cache:
         vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
-        if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
+        if os.path.isfile(vmfb_path):
             print("Loading flatbuffer from {}".format(vmfb_path))
             shark_module.load_module(vmfb_path)
-        else:
-            if args.save_vmfb:
-                print("Saving to {}".format(vmfb_path))
-            else:
-                print(
-                    "No vmfb found. Compiling and saving to {}".format(
-                        vmfb_path
-                    )
-                )
-            path = shark_module.save_module(
-                os.getcwd(), extended_name, extra_args
-            )
-            shark_module.load_module(path)
-    else:
-        shark_module.compile(extra_args)
+            return shark_module
+    print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
+    path = shark_module.save_module(os.getcwd(), extended_name, extra_args)
+    shark_module.load_module(path)
     return shark_module
