Skip to content

Commit 32e1ba8

Browse files
author
Prashant Kumar
committed
Add batch_size support for stable diffusion.
1 parent 1939376 commit 32e1ba8

File tree

4 files changed

+39
-14
lines changed

4 files changed

+39
-14
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@
2424
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
2525
UNET_FP16_TUNED = "unet_fp16_tuned"
2626

27+
BATCH_SIZE = len(args.prompts)
28+
29+
if BATCH_SIZE not in [1, 2]:
30+
import sys
31+
32+
sys.exit("Only batch size 1 and 2 are supported.")
33+
34+
if BATCH_SIZE > 1 and args.precision != "fp16":
35+
sys.exit("batch size > 1 is supported for fp16 model.")
36+
37+
38+
if BATCH_SIZE != 1:
39+
TUNED_GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
40+
UNET_FP16_TUNED = f"unet_fp16_{BATCH_SIZE}"
41+
VAE_FP16 = f"vae_fp16_{BATCH_SIZE}"
42+
2743
# Helper function to profile the vulkan device.
2844
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
2945
if args.vulkan_debug_utils and "vulkan" in args.device:
@@ -67,6 +83,10 @@ def get_models():
6783
vae_args = IREE_EXTRA_ARGS
6884
unet_name = UNET_FP16
6985
vae_name = VAE_FP16
86+
87+
if batch_size > 1:
88+
vae_args = []
89+
7090
if args.import_mlir == True:
7191
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
7292
model_name=UNET_FP16
@@ -112,8 +132,7 @@ def get_models():
112132
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
113133
)
114134

115-
prompt = [args.prompt]
116-
135+
prompt = args.prompts
117136
height = 512 # default height of Stable Diffusion
118137
width = 512 # default width of Stable Diffusion
119138

@@ -211,4 +230,5 @@ def get_models():
211230
print("Total image generation runtime (s): {}".format(time.time() - start))
212231

213232
pil_images = [Image.fromarray(image) for image in images]
214-
pil_images[0].save(f"{args.prompt}.jpg")
233+
for i in range(batch_size):
234+
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")

shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
YOUR_TOKEN = "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"
77

88

9+
BATCH_SIZE = len(args.prompts)
10+
11+
912
def get_vae32(model_name="vae_fp32"):
1013
class VaeModel(torch.nn.Module):
1114
def __init__(self):
@@ -21,7 +24,7 @@ def forward(self, input):
2124
return (x / 2 + 0.5).clamp(0, 1)
2225

2326
vae = VaeModel()
24-
vae_input = torch.rand(1, 4, 64, 64)
27+
vae_input = torch.rand(BATCH_SIZE, 4, 64, 64)
2528
shark_vae = compile_through_fx(
2629
vae,
2730
(vae_input,),
@@ -47,7 +50,7 @@ def forward(self, input):
4750

4851
vae = VaeModel()
4952
vae = vae.half().cuda()
50-
vae_input = torch.rand(1, 4, 64, 64, dtype=torch.half).cuda()
53+
vae_input = torch.rand(BATCH_SIZE, 4, 64, 64, dtype=torch.half).cuda()
5154
shark_vae = compile_through_fx(
5255
vae,
5356
(vae_input,),
@@ -143,8 +146,10 @@ def forward(self, latent, timestep, text_embedding, sigma):
143146

144147
unet = UnetModel()
145148
unet = unet.half().cuda()
146-
latent_model_input = torch.rand([1, 4, 64, 64]).half().cuda()
147-
text_embeddings = torch.rand([2, args.max_length, 768]).half().cuda()
149+
latent_model_input = torch.rand([BATCH_SIZE, 4, 64, 64]).half().cuda()
150+
text_embeddings = (
151+
torch.rand([2 * BATCH_SIZE, args.max_length, 768]).half().cuda()
152+
)
148153
sigma = torch.tensor(1).to(torch.float32)
149154
shark_unet = compile_through_fx(
150155
unet,
@@ -185,8 +190,8 @@ def forward(self, latent, timestep, text_embedding, sigma):
185190
return noise_pred
186191

187192
unet = UnetModel()
188-
latent_model_input = torch.rand([1, 4, 64, 64])
189-
text_embeddings = torch.rand([2, args.max_length, 768])
193+
latent_model_input = torch.rand([BATCH_SIZE, 4, 64, 64])
194+
text_embeddings = torch.rand([2 * BATCH_SIZE, args.max_length, 768])
190195
sigma = torch.tensor(1).to(torch.float32)
191196
shark_unet = compile_through_fx(
192197
unet,

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
)
66

77
p.add_argument(
8-
"--prompt",
9-
type=str,
10-
default="a photograph of an astronaut riding a horse",
11-
help="the text to generate image of.",
8+
"--prompts",
9+
nargs="+",
10+
default=["a photograph of an astronaut riding a horse"],
11+
help="text of which images to be generated.",
1212
)
1313
p.add_argument(
1414
"--device", type=str, default="cpu", help="device to run the model."

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def strip_overloads(gm):
9191
frontend="torch",
9292
)
9393

94-
mlir_module, func_name = mlir_importer.import_mlir()
94+
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
9595

9696
shark_module = SharkInference(
9797
mlir_module,

0 commit comments

Comments (0)