Commit 18689af
Author: Prashant Kumar
Parent: 64d6da7

    Make separate function for each model.
3 files changed: +109 -112 lines
shark/examples/shark_inference/stable_diffusion/main.py (2 additions, 111 deletions)
```diff
@@ -5,41 +5,10 @@
 from tqdm.auto import tqdm
 import numpy as np
 from stable_args import args
-from model_wrappers import (
-    get_vae32,
-    get_vae16,
-    get_unet16_wrapped,
-    get_unet32_wrapped,
-    get_clipped_text,
-)
 from utils import get_shark_model
+from opt_params import get_unet, get_vae, get_clip
 import time
 
-GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
-VAE_FP16 = "vae_fp16"
-VAE_FP32 = "vae_fp32"
-UNET_FP16 = "unet_fp16"
-UNET_FP32 = "unet_fp32"
-IREE_EXTRA_ARGS = []
-
-TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
-UNET_FP16_TUNED = "unet_fp16_tunedv2"
-
-BATCH_SIZE = len(args.prompts)
-
-if BATCH_SIZE not in [1, 2]:
-    import sys
-
-    sys.exit("Only batch size 1 and 2 are supported.")
-
-if BATCH_SIZE > 1 and args.precision != "fp16":
-    sys.exit("batch size > 1 is supported for fp16 model.")
-
-
-if BATCH_SIZE != 1:
-    TUNED_GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
-    UNET_FP16_TUNED = f"unet_fp16_{BATCH_SIZE}"
-    VAE_FP16 = f"vae_fp16_{BATCH_SIZE}"
 
 # Helper function to profile the vulkan device.
 def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
@@ -58,87 +27,9 @@ def end_profiling(device):
     return device.end_profiling()
 
 
-def get_models():
-    global IREE_EXTRA_ARGS
-    if args.precision == "fp16":
-        IREE_EXTRA_ARGS += [
-            "--iree-flow-enable-padding-linalg-ops",
-            "--iree-flow-linalg-ops-padding-size=32",
-        ]
-        if args.use_tuned:
-            unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
-            vae_gcloud_bucket = GCLOUD_BUCKET
-            unet_args = IREE_EXTRA_ARGS
-            vae_args = IREE_EXTRA_ARGS + [
-                "--iree-flow-enable-conv-nchw-to-nhwc-transform"
-            ]
-            unet_name = UNET_FP16_TUNED
-            vae_name = VAE_FP16
-        else:
-            unet_gcloud_bucket = GCLOUD_BUCKET
-            vae_gcloud_bucket = GCLOUD_BUCKET
-            IREE_EXTRA_ARGS += [
-                "--iree-flow-enable-conv-nchw-to-nhwc-transform"
-            ]
-            unet_args = IREE_EXTRA_ARGS
-            vae_args = IREE_EXTRA_ARGS
-            unet_name = UNET_FP16
-            vae_name = VAE_FP16
-
-        if batch_size > 1:
-            vae_args = []
-
-        if args.import_mlir == True:
-            return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
-                model_name=UNET_FP16
-            )
-        else:
-            return get_shark_model(
-                vae_gcloud_bucket,
-                vae_name,
-                vae_args,
-            ), get_shark_model(
-                unet_gcloud_bucket,
-                unet_name,
-                unet_args,
-            )
-
-    elif args.precision == "fp32":
-        IREE_EXTRA_ARGS += [
-            "--iree-flow-enable-conv-nchw-to-nhwc-transform",
-            "--iree-flow-enable-padding-linalg-ops",
-            "--iree-flow-linalg-ops-padding-size=16",
-        ]
-        if args.import_mlir == True:
-            return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
-                model_name=UNET_FP32
-            )
-        else:
-            return get_shark_model(
-                GCLOUD_BUCKET,
-                VAE_FP32,
-                IREE_EXTRA_ARGS,
-            ), get_shark_model(
-                GCLOUD_BUCKET,
-                UNET_FP32,
-                IREE_EXTRA_ARGS,
-            )
-
-
 if __name__ == "__main__":
 
     dtype = torch.float32 if args.precision == "fp32" else torch.half
-    if len(args.iree_vulkan_target_triple) > 0:
-        IREE_EXTRA_ARGS.append(
-            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
-        )
-
-    clip_model = "clip_text"
-    clip_extra_args = [
-        "--iree-flow-linalg-ops-padding-size=16",
-        "--iree-flow-enable-padding-linalg-ops",
-    ]
-    clip = get_shark_model(GCLOUD_BUCKET, clip_model, clip_extra_args)
 
     prompt = args.prompts
     height = 512  # default height of Stable Diffusion
@@ -154,7 +45,7 @@ def get_models():
 
     batch_size = len(prompt)
 
-    vae, unet = get_models()
+    unet, vae, clip = get_unet(), get_vae(), get_clip()
 
     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
```
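The net effect in main.py: the monolithic get_models(), which mutated the shared IREE_EXTRA_ARGS global and returned the VAE and UNet together, is gone, and each model is resolved by its own function. A minimal sketch of the per-model resolution this enables, with a hypothetical Args stand-in for the stable_args namespace (only artifact names visible in this diff appear):

```python
from dataclasses import dataclass


@dataclass
class Args:
    """Illustrative stand-in for the parsed stable_args namespace."""

    precision: str = "fp16"
    use_tuned: bool = False


def unet_artifact(args: Args) -> tuple:
    """Mirror get_unet()'s (bucket, model_name) choice, without global state."""
    if args.precision == "fp16" and args.use_tuned:
        return ("gs://shark_tank/quinn", "unet_fp16_tunedv2")
    if args.precision == "fp16":
        return ("gs://shark_tank/prashant_nod", "unet_fp16")
    return ("gs://shark_tank/prashant_nod", "unet_fp32")


print(unet_artifact(Args(precision="fp16", use_tuned=True)))
# -> ('gs://shark_tank/quinn', 'unet_fp16_tunedv2')
```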

shark/examples/shark_inference/stable_diffusion/opt_params.py (new file, 104 additions)
```diff
@@ -0,0 +1,104 @@
+from model_wrappers import (
+    get_vae32,
+    get_vae16,
+    get_unet16_wrapped,
+    get_unet32_wrapped,
+    get_clipped_text,
+)
+from stable_args import args
+from utils import get_shark_model
+
+BATCH_SIZE = len(args.prompts)
+if BATCH_SIZE != 1:
+    import sys
+
+    sys.exit("Only batch size 1 is supported.")
+
+
+def get_unet():
+    iree_flags = []
+    if len(args.iree_vulkan_target_triple) > 0:
+        iree_flags.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+    # Tuned model is present for `fp16` precision.
+    if args.precision == "fp16":
+        if args.use_tuned:
+            bucket = "gs://shark_tank/quinn"
+            model_name = "unet_fp16_tunedv2"
+            iree_flags += [
+                "--iree-flow-enable-padding-linalg-ops",
+                "--iree-flow-linalg-ops-padding-size=32",
+            ]
+            # TODO: Pass iree_flags to the exported model.
+            if args.import_mlir:
+                return get_unet16_wrapped(model_name=model_name)
+            return get_shark_model(bucket, model_name, iree_flags)
+        else:
+            bucket = "gs://shark_tank/prashant_nod"
+            model_name = "unet_fp16"
+            iree_flags += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
+            if args.import_mlir:
+                return get_unet16_wrapped(model_name=model_name)
+            return get_shark_model(bucket, model_name, iree_flags)
+
+    # Tuned model is not present for `fp32` case.
+    if args.precision == "fp32":
+        bucket = "gs://shark_tank/prashant_nod"
+        model_name = "unet_fp32"
+        iree_flags += [
+            "--iree-flow-enable-conv-nchw-to-nhwc-transform",
+            "--iree-flow-enable-padding-linalg-ops",
+            "--iree-flow-linalg-ops-padding-size=16",
+        ]
+        if args.import_mlir:
+            return get_unet32_wrapped(model_name=model_name)
+        return get_shark_model(bucket, model_name, iree_flags)
+
+
+def get_vae():
+    iree_flags = []
+    if len(args.iree_vulkan_target_triple) > 0:
+        iree_flags.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+    if args.precision == "fp16":
+        bucket = "gs://shark_tank/prashant_nod"
+        model_name = "vae_fp16"
+        iree_flags += [
+            "--iree-flow-enable-conv-nchw-to-nhwc-transform",
+            "--iree-flow-enable-padding-linalg-ops",
+            "--iree-flow-linalg-ops-padding-size=32",
+        ]
+        if args.import_mlir:
+            return get_vae16(model_name)
+        return get_shark_model(bucket, model_name, iree_flags)
+
+    if args.precision == "fp32":
+        bucket = "gs://shark_tank/prashant_nod"
+        model_name = "vae_fp32"
+        iree_flags += [
+            "--iree-flow-enable-conv-nchw-to-nhwc-transform",
+            "--iree-flow-enable-padding-linalg-ops",
+            "--iree-flow-linalg-ops-padding-size=16",
+        ]
+        if args.import_mlir:
+            return get_vae32(model_name)
+        return get_shark_model(bucket, model_name, iree_flags)
+
+
+def get_clip():
+    iree_flags = []
+    if len(args.iree_vulkan_target_triple) > 0:
+        iree_flags.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+    bucket = "gs://shark_tank/prashant_nod"
+    model_name = "clip_text"
+    iree_flags += [
+        "--iree-flow-linalg-ops-padding-size=16",
+        "--iree-flow-enable-padding-linalg-ops",
+    ]
+    if args.import_mlir:
+        return get_clipped_text(model_name)
+    return get_shark_model(bucket, model_name, iree_flags)
```
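All three getters follow the same shape: pick a bucket and artifact name from args.precision and args.use_tuned, accumulate per-model IREE flags, then either re-import the model through MLIR when args.import_mlir is set or fetch the precompiled module. A hedged usage sketch; the command-line spellings are assumptions about stable_args, not confirmed by this diff:

```python
# Hypothetical invocation (flag spellings assumed):
#   python main.py --precision fp16 --use_tuned --prompts "an astronaut riding a horse"
from opt_params import get_unet, get_vae, get_clip

unet = get_unet()  # tuned fp16 UNet from gs://shark_tank/quinn
vae = get_vae()    # fp16 VAE from gs://shark_tank/prashant_nod
clip = get_clip()  # "clip_text" encoder with padding flags
```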

shark/examples/shark_inference/stable_diffusion/utils.py (3 additions, 1 deletion)
```diff
@@ -6,7 +6,9 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from stable_args import args
 from torch._decomp import get_decompositions
-import torch_mlir
+
+if args.import_mlir:
+    import torch_mlir
 
 
 def _compile_module(shark_module, model_name, extra_args=[]):
```
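This makes torch_mlir an opt-in dependency: it is imported only when the user asks to re-import models through MLIR, so running purely from precompiled artifacts no longer requires torch-mlir to be installed. A standalone sketch of the same guarded-import pattern (the flag name mirrors the diff; the rest is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--import_mlir", action="store_true")
args = parser.parse_args()

if args.import_mlir:
    # The heavy optional dependency is paid for only on the opt-in path;
    # a missing install fails here rather than at program startup.
    import torch_mlir  # noqa: F401
```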
