Skip to content

Commit 9a36816

Browse files
authored
[SD][CLI] Add a warmup phase (huggingface#670)
1 parent 7986b9b commit 9a36816

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def end_profiling(device):
118118
subfolder="scheduler",
119119
)
120120

121+
latents = torch.randn(
122+
(batch_size, 4, height // 8, width // 8),
123+
generator=generator,
124+
dtype=torch.float32,
125+
).to(dtype)
126+
# Warmup phase to improve performance.
127+
if args.warmup_count >= 1:
128+
vae_warmup_input = torch.clone(latents).detach().numpy()
129+
clip_warmup_input = torch.randint(1, 2, (2, 77))
130+
for i in range(args.warmup_count):
131+
vae.forward((vae_warmup_input,))
132+
clip.forward((clip_warmup_input,))
133+
121134
start = time.time()
122135

123136
text_input = tokenizer(

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,11 @@
178178
help="flag for hiding the details of iteration/sec for each step.",
179179
)
180180

181+
p.add_argument(
182+
"--warmup_count",
183+
type=int,
184+
default=0,
185+
help="flag setting warmup count for clip and vae [>= 0].",
186+
)
187+
181188
args = p.parse_args()

0 commit comments

Comments
 (0)