Skip to content

Commit 25931d4

Browse files
author
Gaurav Shukla
authored
[WEB] Update stable diffusion UI and enable live preview (huggingface#447)
This commit enables live preview feature and also updates stable diffusion web UI. Signed-Off-by: Gaurav Shukla <[email protected]> Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 024c5e1 commit 25931d4

File tree

3 files changed

+112
-64
lines changed

3 files changed

+112
-64
lines changed

shark/iree_utils/vulkan_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,8 @@ def get_vulkan_triple_flag(extra_args=[]):
4848
elif all(x in vulkan_device for x in ("RTX", "3090")):
4949
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
5050
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
51-
elif any(x in vulkan_device for x in ("Radeon", "AMD")):
52-
print(
53-
"Found AMD Radeon RX 6000 series device. Using rdna2-unknown-linux"
54-
)
51+
elif "AMD" in vulkan_device:
52+
print("Found AMD device. Using rdna2-unknown-linux")
5553
return "-iree-vulkan-target-triple=rdna2-unknown-linux"
5654
else:
5755
print(

web/index.py

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -141,67 +141,96 @@ def debug_event(debug):
141141
save_vmfb
142142
) = (
143143
iree_vulkan_target_triple
144+
) = (
145+
live_preview
144146
) = debug = stable_diffusion = generated_img = std_output = None
147+
examples = [
148+
["A high tech solarpunk utopia in the Amazon rainforest"],
149+
["A pikachu fine dining with a view to the Eiffel Tower"],
150+
["A mecha robot in a favela in expressionist style"],
151+
["an insect robot preparing a delicious meal"],
152+
[
153+
"A small cabin on top of a snowy mountain in the style of Disney, artstation"
154+
],
155+
]
156+
145157
with gr.Row():
146158
with gr.Column(scale=1, min_width=600):
147-
prompt = gr.Textbox(
148-
label="Prompt",
149-
value="a photograph of an astronaut riding a horse",
150-
lines=2,
151-
)
152-
scheduler = gr.Radio(
153-
label="Scheduler",
154-
value="LMS",
155-
choices=["PNDM", "LMS", "DDIM"],
156-
visible=False,
157-
)
158-
iters_count = gr.Slider(
159-
1,
160-
24,
161-
value=1,
162-
step=1,
163-
label="Iteration Count",
164-
visible=False,
165-
)
166-
batch_size = gr.Slider(
167-
1,
168-
4,
169-
value=1,
170-
step=1,
171-
label="Batch Size",
172-
visible=False,
173-
)
174-
steps = gr.Slider(1, 100, value=20, step=1, label="Steps")
175-
guidance = gr.Slider(
176-
0, 50, value=7.5, step=0.1, label="Guidance Scale"
177-
)
178-
height = gr.Slider(
179-
384, 768, value=512, step=64, label="Height"
180-
)
181-
width = gr.Slider(
182-
384, 768, value=512, step=64, label="Width"
183-
)
184-
seed = gr.Textbox(value="42", max_lines=1, label="Seed")
185-
precision = gr.Radio(
186-
label="Precision",
187-
value="fp32",
188-
choices=["fp16", "fp32"],
189-
)
190-
device = gr.Radio(
191-
label="Device",
192-
value="vulkan",
193-
choices=["cpu", "cuda", "vulkan"],
194-
)
159+
with gr.Group():
160+
prompt = gr.Textbox(
161+
label="Prompt",
162+
value="a photograph of an astronaut riding a horse",
163+
)
164+
ex = gr.Examples(
165+
examples=examples,
166+
inputs=prompt,
167+
cache_examples=False,
168+
)
169+
with gr.Row():
170+
iters_count = gr.Slider(
171+
1,
172+
24,
173+
value=1,
174+
step=1,
175+
label="Iteration Count",
176+
visible=False,
177+
)
178+
batch_size = gr.Slider(
179+
1,
180+
4,
181+
value=1,
182+
step=1,
183+
label="Batch Size",
184+
visible=False,
185+
)
186+
with gr.Row():
187+
steps = gr.Slider(
188+
1, 100, value=20, step=1, label="Steps"
189+
)
190+
guidance = gr.Slider(
191+
0, 50, value=7.5, step=0.1, label="Guidance Scale"
192+
)
193+
with gr.Row():
194+
height = gr.Slider(
195+
384, 768, value=512, step=64, label="Height"
196+
)
197+
width = gr.Slider(
198+
384, 768, value=512, step=64, label="Width"
199+
)
200+
with gr.Row():
201+
precision = gr.Radio(
202+
label="Precision",
203+
value="fp32",
204+
choices=["fp16", "fp32"],
205+
)
206+
device = gr.Radio(
207+
label="Device",
208+
value="vulkan",
209+
choices=["cpu", "cuda", "vulkan"],
210+
)
211+
with gr.Row():
212+
scheduler = gr.Radio(
213+
label="Scheduler",
214+
value="LMS",
215+
choices=["PNDM", "LMS", "DDIM"],
216+
interactive=False,
217+
)
218+
seed = gr.Textbox(
219+
value="42", max_lines=1, label="Seed"
220+
)
195221
with gr.Row():
196222
load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
197223
save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
224+
debug = gr.Checkbox(label="DEBUG", value=False)
225+
live_preview = gr.Checkbox(
226+
label="live preview", value=False
227+
)
198228
iree_vulkan_target_triple = gr.Textbox(
199229
value="",
200230
max_lines=1,
201231
label="IREE VULKAN TARGET TRIPLE",
202232
visible=False,
203233
)
204-
debug = gr.Checkbox(label="DEBUG", value=False)
205234
stable_diffusion = gr.Button("Generate image from prompt")
206235
with gr.Column(scale=1, min_width=600):
207236
generated_img = gr.Image(type="pil", shape=(100, 100))
@@ -211,6 +240,7 @@ def debug_event(debug):
211240
lines=10,
212241
visible=False,
213242
)
243+
214244
debug.change(
215245
debug_event,
216246
inputs=[debug],
@@ -234,8 +264,10 @@ def debug_event(debug):
234264
load_vmfb,
235265
save_vmfb,
236266
iree_vulkan_target_triple,
267+
live_preview,
237268
],
238269
outputs=[generated_img, std_output],
239270
)
240271

272+
shark_web.queue()
241273
shark_web.launch(share=True, server_port=8080, enable_queue=True)

web/models/stable_diffusion/main.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
load_vmfb: bool,
4343
save_vmfb: bool,
4444
iree_vulkan_target_triple: str,
45+
live_preview: bool,
4546
import_mlir: bool = False,
4647
max_length: int = 77,
4748
):
@@ -59,6 +60,7 @@ def __init__(
5960
self.load_vmfb = load_vmfb
6061
self.save_vmfb = save_vmfb
6162
self.iree_vulkan_target_triple = iree_vulkan_target_triple
63+
self.live_preview = live_preview
6264
self.import_mlir = import_mlir
6365
self.max_length = max_length
6466

@@ -114,6 +116,7 @@ def stable_diff_inf(
114116
load_vmfb: bool,
115117
save_vmfb: bool,
116118
iree_vulkan_target_triple: str,
119+
live_preview: bool,
117120
):
118121

119122
global IREE_EXTRA_ARGS
@@ -178,6 +181,7 @@ def stable_diff_inf(
178181
load_vmfb,
179182
save_vmfb,
180183
iree_vulkan_target_triple,
184+
live_preview,
181185
)
182186
dtype = torch.float32 if args.precision == "fp32" else torch.half
183187
if len(args.iree_vulkan_target_triple) > 0:
@@ -228,6 +232,7 @@ def stable_diff_inf(
228232
text_embeddings_numpy = text_embeddings.detach().numpy()
229233

230234
avg_ms = 0
235+
out_img = None
231236
for i, t in tqdm(enumerate(scheduler.timesteps)):
232237

233238
if DEBUG:
@@ -248,25 +253,38 @@ def stable_diff_inf(
248253
log_write.write(f"time={step_ms}ms")
249254
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
250255

256+
if live_preview:
257+
time.sleep(0.1)
258+
scaled_latents = 1 / 0.18215 * latents
259+
latents_numpy = scaled_latents.detach().numpy()
260+
image = vae.forward((latents_numpy,))
261+
image = torch.from_numpy(image)
262+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
263+
images = (image * 255).round().astype("uint8")
264+
pil_images = [Image.fromarray(image) for image in images]
265+
out_img = pil_images[0]
266+
yield out_img, ""
267+
251268
# scale and decode the image latents with vae
252-
latents = 1 / 0.18215 * latents
253-
latents_numpy = latents.detach().numpy()
254-
image = vae.forward((latents_numpy,))
255-
image = torch.from_numpy(image)
256-
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
257-
images = (image * 255).round().astype("uint8")
258-
pil_images = [Image.fromarray(image) for image in images]
269+
if not live_preview:
270+
latents = 1 / 0.18215 * latents
271+
latents_numpy = latents.detach().numpy()
272+
image = vae.forward((latents_numpy,))
273+
image = torch.from_numpy(image)
274+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
275+
images = (image * 255).round().astype("uint8")
276+
pil_images = [Image.fromarray(image) for image in images]
277+
out_img = pil_images[0]
278+
259279
avg_ms = 1000 * avg_ms / args.steps
260280
if DEBUG:
261281
log_write.write(f"\nAverage step time: {avg_ms}ms/it")
262282

263-
print("total images:", len(pil_images))
264-
output = pil_images[0]
265283
# save the output image with the prompt name.
266-
output.save(os.path.join(output_loc))
284+
out_img.save(os.path.join(output_loc))
267285
log_write.close()
268286

269287
std_output = ""
270288
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
271289
std_output = log_read.read()
272-
return output, std_output
290+
yield out_img, std_output

0 commit comments

Comments (0)