2121UNET_FP32 = "unet_fp32"
2222IREE_EXTRA_ARGS = []
2323
24+ TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
25+ UNET_FP16_TUNED = "unet_fp16_tuned"
26+
2427# Helper function to profile the vulkan device.
2528def start_profiling (file_path = "foo.rdc" , profiling_mode = "queue" ):
2629 if args .vulkan_debug_utils and "vulkan" in args .device :
@@ -42,24 +45,41 @@ def get_models():
4245 global IREE_EXTRA_ARGS
4346 if args .precision == "fp16" :
4447 IREE_EXTRA_ARGS += [
45- "--iree-flow-enable-conv-nchw-to-nhwc-transform" ,
4648 "--iree-flow-enable-padding-linalg-ops" ,
4749 "--iree-flow-linalg-ops-padding-size=32" ,
48- "--iree-spirv-unify-aliased-resources=false" ,
4950 ]
51+ if args .use_tuned :
52+ unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
53+ vae_gcloud_bucket = GCLOUD_BUCKET
54+ unet_args = IREE_EXTRA_ARGS
55+ vae_args = IREE_EXTRA_ARGS + [
56+ "--iree-flow-enable-conv-nchw-to-nhwc-transform"
57+ ]
58+ unet_name = UNET_FP16_TUNED
59+ vae_name = VAE_FP16
60+ else :
61+ unet_gcloud_bucket = GCLOUD_BUCKET
62+ vae_gcloud_bucket = GCLOUD_BUCKET
63+ IREE_EXTRA_ARGS += [
64+ "--iree-flow-enable-conv-nchw-to-nhwc-transform"
65+ ]
66+ unet_args = IREE_EXTRA_ARGS
67+ vae_args = IREE_EXTRA_ARGS
68+ unet_name = UNET_FP16
69+ vae_name = VAE_FP16
5070 if args .import_mlir == True :
5171 return get_vae16 (model_name = VAE_FP16 ), get_unet16_wrapped (
5272 model_name = UNET_FP16
5373 )
5474 else :
5575 return get_shark_model (
56- GCLOUD_BUCKET ,
57- VAE_FP16 ,
58- IREE_EXTRA_ARGS ,
76+ vae_gcloud_bucket ,
77+ vae_name ,
78+ vae_args ,
5979 ), get_shark_model (
60- GCLOUD_BUCKET ,
61- UNET_FP16 ,
62- IREE_EXTRA_ARGS ,
80+ unet_gcloud_bucket ,
81+ unet_name ,
82+ unet_args ,
6383 )
6484
6585 elif args .precision == "fp32" :
0 commit comments