Skip to content

Commit dd22c65

Browse files
authored
Add CUDA tuned models for SD variants (huggingface#814)
1 parent 48137ce commit dd22c65

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

shark/examples/shark_inference/stable_diffusion/resources/model_db.json

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,23 +5,28 @@
55
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
66
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
77
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
8+
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
89
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
910
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
11+
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
1012
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
1113
"openjourney/tuned":"gs://shark_tank/sd_tuned",
1214
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
1315
},
1416
{
1517
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
1618
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
19+
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
1720
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
1821
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
1922
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
23+
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
2024
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
2125
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
2226
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
2327
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
2428
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
29+
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
2530
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
2631
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
2732
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
@@ -39,18 +44,22 @@
3944
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
4045
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
4146
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
47+
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
4248
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
4349
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
4450
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
51+
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
4552
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
4653
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
4754
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
4855
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
4956
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
5057
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
58+
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
5159
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
5260
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
5361
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
62+
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
5463
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
5564
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
5665
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",

shark/examples/shark_inference/stable_diffusion/sd_annotation.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,10 @@
4444
download_public_file(full_gs_url, winograd_config_dir, True)
4545

4646
if args.annotation_model == "unet" or device == "cuda":
47-
if (
48-
args.variant in ["anythingv3", "analogdiffusion"]
49-
or args.annotation_model == "vae"
50-
):
47+
if args.variant in ["anythingv3", "analogdiffusion"]:
48+
args.max_length = 77
49+
args.version = "v1_4"
50+
if args.annotation_model == "vae":
5151
args.max_length = 77
5252
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
5353
full_gs_url = config_bucket + config_name

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -205,11 +205,10 @@ def set_init_device_flags():
205205

206206
# Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
207207
if (
208-
args.variant == "stablediffusion"
208+
args.variant in ["stablediffusion", "anythingv3", "analogdiffusion"]
209209
and args.precision == "fp16"
210210
and "cuda" in args.device
211211
and get_cuda_sm_cc() == "sm_80"
212-
and args.version == "v2_1base"
213212
):
214213
args.use_tuned = True
215214

0 commit comments

Comments
 (0)