Skip to content

Commit 88937fc

Browse files
author
Gaurav Shukla
committed
[WEB] Add vulkan-heap-block-size flag
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent f80b85f commit 88937fc

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

web/models/stable_diffusion/arguments.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from transformers import CLIPTokenizer
22
from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
33
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
4+
from models.stable_diffusion.utils import set_iree_runtime_flags
5+
import os
46

57

68
class Arguments:
@@ -24,6 +26,7 @@ def __init__(
2426
import_mlir: bool = False,
2527
max_length: int = 77,
2628
use_tuned: bool = False,
29+
vulkan_large_heap_block_size: int = 4294967296,
2730
):
2831
self.prompt = prompt
2932
self.scheduler = scheduler
@@ -43,6 +46,7 @@ def __init__(
4346
self.import_mlir = import_mlir
4447
self.max_length = max_length
4548
self.use_tuned = use_tuned
49+
self.vulkan_large_heap_block_size = vulkan_large_heap_block_size
4650

4751
def set_params(
4852
self,
@@ -81,6 +85,9 @@ def set_params(
8185
self.import_mlir = import_mlir
8286

8387

88+
output_dir = "./stored_results/stable_diffusion"
89+
os.makedirs(output_dir, exist_ok=True)
90+
8491
schedulers = dict()
8592
# set scheduler value
8693
schedulers["PNDM"] = PNDMScheduler(
@@ -112,6 +119,7 @@ def set_params(
112119
# cache vae and unet.
113120
args = Arguments()
114121
args.device = "vulkan"
122+
set_iree_runtime_flags(args)
115123
cache_obj["vae_fp16_vulkan"], cache_obj["unet_fp16_vulkan"] = get_vae(
116124
args
117125
), get_unet(args)

web/models/stable_diffusion/main.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
from numpy import iinfo
1111
from random import randint
1212
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
13-
from models.stable_diffusion.arguments import args, schedulers, cache_obj
14-
15-
output_dir = "./stored_results/stable_diffusion"
16-
os.makedirs(output_dir, exist_ok=True)
13+
from models.stable_diffusion.arguments import (
14+
args,
15+
schedulers,
16+
cache_obj,
17+
output_dir,
18+
)
1719

1820

1921
def stable_diff_inf(
@@ -153,6 +155,7 @@ def stable_diff_inf(
153155
avg_ms += step_time
154156
step_ms = int((step_time) * 1000)
155157
text_output += f"Time = {step_ms}ms."
158+
print(f" \nIteration = {i}, Time = {step_ms}ms")
156159
latents = scheduler_obj.step(noise_pred, i, latents)["prev_sample"]
157160

158161
if live_preview and i % 5 == 0:
@@ -178,6 +181,7 @@ def stable_diff_inf(
178181

179182
avg_ms = 1000 * avg_ms / args.steps
180183
text_output += f"\n\nAverage step time: {avg_ms}ms/it"
184+
print(f"\n\nAverage step time: {avg_ms}ms/it")
181185

182186
total_time = time.time() - start
183187
text_output += f"\n\nTotal image generation time: {total_time}sec"

web/models/stable_diffusion/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
import torch
44
from shark.shark_inference import SharkInference
55
from shark.shark_importer import import_with_fx
6+
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
7+
8+
9+
def set_iree_runtime_flags(args):
10+
vulkan_runtime_flags = [
11+
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
12+
]
13+
if "vulkan" in args.device:
14+
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
15+
return
616

717

818
def _compile_module(args, shark_module, model_name, extra_args=[]):

0 commit comments

Comments
 (0)