
Commit 6eb47c1

add multi-run in single execution (huggingface#812)

1 parent 5a1fc66

2 files changed (+157 -142 lines)

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 150 additions & 142 deletions
@@ -86,16 +86,6 @@ def end_profiling(device):
     # Scale for classifier-free guidance
     guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
 
-    # Handle out of range seeds.
-    uint32_info = np.iinfo(np.uint32)
-    uint32_min, uint32_max = uint32_info.min, uint32_info.max
-    seed = args.seed
-    if seed < uint32_min or seed >= uint32_max:
-        seed = randint(uint32_min, uint32_max)
-    generator = torch.manual_seed(
-        seed
-    )  # Seed generator to create the inital latent noise
-
     # TODO: Add support for batch_size > 1.
     batch_size = len(prompt)
     if batch_size != 1:
@@ -144,139 +134,157 @@ def end_profiling(device):
         "stabilityai/stable-diffusion-2-1-base",
         subfolder="scheduler",
     )
+    for run in range(args.runs):
+        # Handle out of range seeds.
+        uint32_info = np.iinfo(np.uint32)
+        uint32_min, uint32_max = uint32_info.min, uint32_info.max
+        seed = args.seed
+        if run >= 1 or seed < uint32_min or seed >= uint32_max:
+            seed = randint(uint32_min, uint32_max)
+        generator = torch.manual_seed(
+            seed
+        )  # Seed generator to create the inital latent noise
+
+        # create a random initial latent.
+        latents = torch.randn(
+            (batch_size, 4, height // 8, width // 8),
+            generator=generator,
+            dtype=torch.float32,
+        ).to(dtype)
+        if run == 0:
+            # Warmup phase to improve performance.
+            if args.warmup_count >= 1:
+                vae_warmup_input = torch.clone(latents).detach().numpy()
+                clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
+                for i in range(args.warmup_count):
+                    vae("forward", (vae_warmup_input,))
+                    clip("forward", (clip_warmup_input,))
+
+        start = time.time()
+        if run == 0:
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=args.max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = tokenizer(
+                neg_prompt,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input = torch.cat(
+                [uncond_input.input_ids, text_input.input_ids]
+            )
 
-    # create a random initial latent.
-    latents = torch.randn(
-        (batch_size, 4, height // 8, width // 8),
-        generator=generator,
-        dtype=torch.float32,
-    ).to(dtype)
-    # Warmup phase to improve performance.
-    if args.warmup_count >= 1:
-        vae_warmup_input = torch.clone(latents).detach().numpy()
-        clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
-        for i in range(args.warmup_count):
-            vae("forward", (vae_warmup_input,))
-            clip("forward", (clip_warmup_input,))
-
-    start = time.time()
-
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=args.max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    max_length = text_input.input_ids.shape[-1]
-    uncond_input = tokenizer(
-        neg_prompt,
-        padding="max_length",
-        max_length=max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
-
-    clip_inf_start = time.time()
-    text_embeddings = clip("forward", (text_input,))
-    clip_inf_end = time.time()
-    text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
-    text_embeddings_numpy = text_embeddings.detach().numpy()
-
-    scheduler.set_timesteps(num_inference_steps)
-    scheduler.is_scale_input_called = True
-
-    latents = latents * scheduler.init_noise_sigma
-
-    avg_ms = 0
-    for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
-        step_start = time.time()
-        if not args.hide_steps:
-            print(f"i = {i} t = {t}", end="")
-        timestep = torch.tensor([t]).to(dtype).detach().numpy()
-        latent_model_input = scheduler.scale_model_input(latents, t)
-        if cpu_scheduling:
-            latent_model_input = latent_model_input.detach().numpy()
-
-        profile_device = start_profiling(file_path="unet.rdc")
-
-        noise_pred = unet(
-            "forward",
-            (
-                latent_model_input,
-                timestep,
-                text_embeddings_numpy,
-                guidance_scale,
-            ),
-            send_to_host=False,
-        )
-
-        end_profiling(profile_device)
+        clip_inf_start = time.time()
+        text_embeddings = clip("forward", (text_input,))
+        clip_inf_end = time.time()
+        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
+        text_embeddings_numpy = text_embeddings.detach().numpy()
+
+        scheduler.set_timesteps(num_inference_steps)
+        scheduler.is_scale_input_called = True
+
+        latents = latents * scheduler.init_noise_sigma
+
+        avg_ms = 0
+        for i, t in tqdm(
+            enumerate(scheduler.timesteps), disable=args.hide_steps
+        ):
+            step_start = time.time()
+            if not args.hide_steps:
+                print(f"i = {i} t = {t}", end="")
+            timestep = torch.tensor([t]).to(dtype).detach().numpy()
+            latent_model_input = scheduler.scale_model_input(latents, t)
+            if cpu_scheduling:
+                latent_model_input = latent_model_input.detach().numpy()
+
+            profile_device = start_profiling(file_path="unet.rdc")
+
+            noise_pred = unet(
+                "forward",
+                (
+                    latent_model_input,
+                    timestep,
+                    text_embeddings_numpy,
+                    guidance_scale,
+                ),
+                send_to_host=False,
+            )
 
+            end_profiling(profile_device)
+
+            if cpu_scheduling:
+                noise_pred = torch.from_numpy(noise_pred.to_host())
+                latents = scheduler.step(noise_pred, t, latents).prev_sample
+            else:
+                latents = scheduler.step(noise_pred, t, latents)
+            step_time = time.time() - step_start
+            avg_ms += step_time
+            step_ms = int((step_time) * 1000)
+            if not args.hide_steps:
+                print(f" ({step_ms}ms)")
+
+        # scale and decode the image latents with vae
+        if args.use_base_vae:
+            latents = 1 / 0.18215 * latents
+        latents_numpy = latents
         if cpu_scheduling:
-            noise_pred = torch.from_numpy(noise_pred.to_host())
-            latents = scheduler.step(noise_pred, t, latents).prev_sample
+            latents_numpy = latents.detach().numpy()
+        profile_device = start_profiling(file_path="vae.rdc")
+        vae_start = time.time()
+        images = vae("forward", (latents_numpy,))
+        vae_end = time.time()
+        end_profiling(profile_device)
+        if args.use_base_vae:
+            image = torch.from_numpy(images)
+            image = (image.detach().cpu() * 255.0).numpy()
+            images = image.round()
+        end_time = time.time()
+
+        avg_ms = 1000 * avg_ms / args.steps
+        clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
+        vae_inf_time = (vae_end - vae_start) * 1000
+        total_time = end_time - start
+
+        print(f"\nStats for run {run}:")
+        print(f"Average step time: {avg_ms}ms/it")
+        print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
+        print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
+        print(f"\nTotal image generation time: {total_time}sec")
+
+        transform = T.ToPILImage()
+        pil_images = [
+            transform(image)
+            for image in torch.from_numpy(images).to(torch.uint8)
+        ]
+
+        if args.output_dir is not None:
+            output_path = Path(args.output_dir)
+            output_path.mkdir(parents=True, exist_ok=True)
         else:
-            latents = scheduler.step(noise_pred, t, latents)
-        step_time = time.time() - step_start
-        avg_ms += step_time
-        step_ms = int((step_time) * 1000)
-        if not args.hide_steps:
-            print(f" ({step_ms}ms)")
-
-    # scale and decode the image latents with vae
-    if args.use_base_vae:
-        latents = 1 / 0.18215 * latents
-    latents_numpy = latents
-    if cpu_scheduling:
-        latents_numpy = latents.detach().numpy()
-    profile_device = start_profiling(file_path="vae.rdc")
-    vae_start = time.time()
-    images = vae("forward", (latents_numpy,))
-    vae_end = time.time()
-    end_profiling(profile_device)
-    if args.use_base_vae:
-        image = torch.from_numpy(images)
-        image = (image.detach().cpu() * 255.0).numpy()
-        images = image.round()
-    end_time = time.time()
-
-    avg_ms = 1000 * avg_ms / args.steps
-    clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
-    vae_inf_time = (vae_end - vae_start) * 1000
-    total_time = end_time - start
-    print(f"\nAverage step time: {avg_ms}ms/it")
-    print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
-    print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
-    print(f"\nTotal image generation time: {total_time}sec")
-
-    transform = T.ToPILImage()
-    pil_images = [
-        transform(image) for image in torch.from_numpy(images).to(torch.uint8)
-    ]
-
-    if args.output_dir is not None:
-        output_path = Path(args.output_dir)
-        output_path.mkdir(parents=True, exist_ok=True)
-    else:
-        output_path = Path.cwd()
-    disk_space_check(output_path, lim=5)
-    for i in range(batch_size):
-        json_store = {
-            "prompt": args.prompts[i],
-            "negative prompt": args.negative_prompts[i],
-            "seed": args.seed,
-            "variant": args.variant,
-            "precision": args.precision,
-            "steps": args.steps,
-            "guidance_scale": args.guidance_scale,
-            "scheduler": args.scheduler,
-        }
-        prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
-        img_name = f"{prompt_slice}_{args.seed}_{i}_{dt.now().strftime('%y%m%d_%H%M%S')}"
-        pil_images[i].save(
-            output_path / f"{img_name}.jpg", quality=95, subsampling=0
-        )
-        with open(output_path / f"{img_name}.json", "w") as f:
-            f.write(json.dumps(json_store, indent=4))
+            output_path = Path.cwd()
+        disk_space_check(output_path, lim=5)
+        for i in range(batch_size):
+            json_store = {
+                "prompt": args.prompts[i],
+                "negative prompt": args.negative_prompts[i],
+                "seed": args.seed,
+                "variant": args.variant,
+                "precision": args.precision,
+                "steps": args.steps,
+                "guidance_scale": args.guidance_scale,
+                "scheduler": args.scheduler,
+            }
+            prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
+            img_name = f"{prompt_slice}_{args.seed}_{run}_{dt.now().strftime('%y%m%d_%H%M%S')}"
+            pil_images[i].save(
+                output_path / f"{img_name}.jpg", quality=95, subsampling=0
+            )
+            with open(output_path / f"{img_name}.json", "w") as f:
+                f.write(json.dumps(json_store, indent=4))
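Note on the main.py change: seed selection now happens once per run. Run 0 keeps --seed when it is a valid uint32, and every later run (or any out-of-range seed) draws a fresh random seed. A minimal standalone sketch of that selection rule, using a hypothetical pick_seed helper for illustration (not part of the commit):

    import numpy as np
    from random import randint

    def pick_seed(run, requested_seed):
        # Mirror the per-run seed handling in the loop above: the first run
        # keeps the requested seed if it fits in uint32; later runs (and
        # out-of-range seeds) fall back to a fresh random seed.
        uint32_info = np.iinfo(np.uint32)
        uint32_min, uint32_max = uint32_info.min, uint32_info.max
        if run >= 1 or requested_seed < uint32_min or requested_seed >= uint32_max:
            return randint(uint32_min, uint32_max)
        return requested_seed

    # e.g. pick_seed(0, 42) -> 42, while pick_seed(1, 42) -> some random uint32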

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 7 additions & 0 deletions
@@ -129,6 +129,13 @@ def path_expand(s):
     default=None,
     help="Directory path to save the output images and json",
 )
+
+p.add_argument(
+    "--runs",
+    type=int,
+    default=1,
+    help="number of images to be generated with random seeds in single execution",
+)
 ##############################################################################
 ### IREE - Vulkan supported flags
 ##############################################################################
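The new --runs flag (default 1) drives the for run in range(args.runs) loop added in main.py, so a hypothetical invocation along the lines of python main.py --prompts "..." --runs 4 (other flags as usual) should generate four images in a single process, reusing the already-compiled CLIP/UNet/VAE modules and running the warmup phase only on the first iteration.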
