Skip to content

Commit 9434981

Browse files
authored
Add random seed generation for seed = -1 in cli (huggingface#689)
1 parent 8b3706f commit 9434981

File tree

1 file changed

+8
-1
lines changed
  • shark/examples/shark_inference/stable_diffusion

1 file changed

+8
-1
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from tqdm.auto import tqdm
1717
import numpy as np
18+
from random import randint
1819
from stable_args import args
1920
from utils import get_shark_model, set_iree_runtime_flags
2021
from opt_params import get_unet, get_vae, get_clip
@@ -59,8 +60,14 @@ def end_profiling(device):
5960
# Scale for classifier-free guidance
6061
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
6162

63+
# Handle out of range seeds.
64+
uint32_info = np.iinfo(np.uint32)
65+
uint32_min, uint32_max = uint32_info.min, uint32_info.max
66+
seed = args.seed
67+
if seed < uint32_min or seed >= uint32_max:
68+
seed = randint(uint32_min, uint32_max)
6269
generator = torch.manual_seed(
63-
args.seed
70+
seed
6471
) # Seed generator to create the inital latent noise
6572

6673
# TODO: Add support for batch_size > 1.

0 commit comments

Comments
 (0)