|
1 | 1 | import os |
2 | 2 | from shark.model_annotation import model_annotation, create_context |
3 | | -from shark.iree_utils._common import run_cmd, iree_target_map |
| 3 | +from shark.iree_utils._common import iree_target_map, run_cmd |
4 | 4 | from shark.shark_downloader import ( |
5 | 5 | download_model, |
6 | 6 | download_public_file, |
7 | 7 | WORKDIR, |
8 | 8 | ) |
9 | 9 | from shark.parser import shark_args |
10 | 10 | from stable_args import args |
11 | | -from opt_params import get_params |
12 | | -from utils import set_init_device_flags |
13 | 11 |
|
14 | 12 |
|
15 | | -set_init_device_flags() |
# Normalize a URI-style device string such as "vulkan://0" down to its
# backend name; a plain device name passes through unchanged.
if "://" in args.device:
    device = args.device.split("://")[0]
else:
    device = args.device
19 | 16 |
|
20 | | -# Downloads the model (Unet or VAE fp16) from shark_tank |
21 | | -shark_args.local_tank_cache = args.local_tank_cache |
22 | | -bucket_key = f"{args.variant}/untuned" |
23 | | -if args.annotation_model == "unet": |
24 | | - model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned" |
25 | | -elif args.annotation_model == "vae": |
26 | | - is_base = "/base" if args.use_base_vae else "" |
27 | | - model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}" |
28 | | - |
29 | | -bucket, model_name, iree_flags = get_params( |
30 | | - bucket_key, model_key, args.annotation_model, "untuned", args.precision |
31 | | -) |
32 | | -mlir_model, func_name, inputs, golden_out = download_model( |
33 | | - model_name, |
34 | | - tank_url=bucket, |
35 | | - frontend="torch", |
36 | | -) |
37 | 17 |
|
38 | | -# Downloads the tuned config files from shark_tank |
39 | | -config_bucket = "gs://shark_tank/sd_tuned/configs/" |
40 | | -if args.use_winograd: |
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
    """Download the untuned base model from shark_tank.

    Builds the tank lookup keys from the CLI flags (`variant`, `version`,
    precision, max_length), resolves them via ``get_params`` and downloads
    the model artifacts.

    Returns:
        (mlir_model, model_name): the MLIR module returned by
        ``download_model`` and the resolved model name.

    Raises:
        ValueError: if ``args.annotation_model`` is neither "unet" nor "vae".
    """
    from opt_params import get_params, version, variant

    shark_args.local_tank_cache = args.local_tank_cache
    bucket_key = f"{variant}/untuned"
    if args.annotation_model == "unet":
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
    elif args.annotation_model == "vae":
        is_base = "/base" if args.use_base_vae else ""
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
    else:
        # Previously an unknown model fell through and crashed later with an
        # unbound `model_key` NameError; fail fast with a clear message.
        raise ValueError(
            f"Unsupported annotation model: {args.annotation_model}"
        )

    bucket, model_name, iree_flags = get_params(
        bucket_key, model_key, args.annotation_model, "untuned", args.precision
    )
    mlir_model, func_name, inputs, golden_out = download_model(
        model_name,
        tank_url=bucket,
        frontend="torch",
    )
    return mlir_model, model_name
| 39 | + |
| 40 | + |
# Download the tuned config files from shark_tank
def load_winograd_configs():
    """Fetch the Winograd annotation config for the current model/device.

    Returns:
        The local path of the downloaded JSON config file.
    """
    file_name = f"{args.annotation_model}_winograd_{device}.json"
    remote_url = "gs://shark_tank/sd_tuned/configs/" + file_name
    local_path = f"{WORKDIR}configs/" + file_name
    print("Loading Winograd config file from ", local_path)
    download_public_file(remote_url, local_path, True)
    return local_path
45 | 50 |
|
46 | | -if args.annotation_model == "unet" or device == "cuda": |
47 | | - if args.variant in ["anythingv3", "analogdiffusion"]: |
| 51 | + |
def load_lower_configs():
    """Fetch the tuned lowering config matching the current model settings.

    NOTE(review): mutates ``args.max_length`` as a side effect for the
    anythingv3/analogdiffusion variants and for vae — downstream code
    appears to rely on the updated value; confirm before refactoring.

    Returns:
        The local path of the downloaded JSON config file.
    """
    from opt_params import version, variant

    config_version = version
    if variant in ["anythingv3", "analogdiffusion"]:
        # These variants reuse the v1_4 tuned configs.
        args.max_length = 77
        config_version = "v1_4"
    if args.annotation_model == "vae":
        args.max_length = 77
    file_name = (
        f"{args.annotation_model}_{config_version}_{args.precision}"
        f"_len{args.max_length}_{device}.json"
    )
    remote_url = "gs://shark_tank/sd_tuned/configs/" + file_name
    local_path = f"{WORKDIR}configs/" + file_name
    print("Loading lowering config file from ", local_path)
    download_public_file(remote_url, local_path, True)
    return local_path
| 68 | + |
56 | 69 |
|
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
    """Apply Winograd annotations to selected conv ops and save the result.

    Args:
        input_mlir: model MLIR contents (or whatever ``model_annotation``
            accepts as ``input_contents``).
        winograd_config_dir: path to the Winograd JSON config file.
        model_name: base name used to derive the output file name.

    Returns:
        (winograd_model, out_file_path): the annotated module and the path
        it was written to.
    """
    # Append the "_tuned" suffix only when the name doesn't already carry it.
    if model_name.split("_")[-1] != "tuned":
        out_file_path = (
            f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
        )
    else:
        out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"

    with create_context() as ctx:
        winograd_model = model_annotation(
            ctx,
            input_contents=input_mlir,
            config_path=winograd_config_dir,
            search_op="conv",
            winograd=True,
        )
        # The `with` block closes the file; the original's extra f.close()
        # after the block was redundant and has been removed.
        with open(out_file_path, "w") as f:
            f.write(str(winograd_model))
    return winograd_model, out_file_path
| 91 | + |
71 | 92 |
|
72 | 93 | # For Unet annotate the model with tuned lowering configs |
73 | | -if args.annotation_model == "unet" or device == "cuda": |
74 | | - if args.use_winograd: |
75 | | - input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir" |
| 94 | +def annotate_with_lower_configs( |
| 95 | + input_mlir, lowering_config_dir, model_name, use_winograd |
| 96 | +): |
| 97 | + if use_winograd: |
76 | 98 | dump_after = "iree-linalg-ext-convert-conv2d-to-winograd" |
77 | 99 | else: |
78 | | - input_mlir = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir" |
79 | 100 | dump_after = "iree-flow-pad-linalg-ops" |
80 | 101 |
|
81 | 102 | # Dump IR after padding/img2col/winograd passes |
|
90 | 111 | device_spec_args = ( |
91 | 112 | f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} " |
92 | 113 | ) |
| 114 | + print("Applying tuned configs on", model_name) |
| 115 | + |
93 | 116 | run_cmd( |
94 | 117 | f"iree-compile {input_mlir} " |
95 | 118 | "--iree-input-type=tm_tensor " |
|
116 | 139 |
|
117 | 140 | # Remove the intermediate mlir and save the final annotated model |
118 | 141 | os.remove(f"{args.annotation_output}/dump_after_winograd.mlir") |
119 | | - output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir" |
120 | | - with open(output_path, "w") as f: |
| 142 | + if model_name.split("_")[-1] != "tuned": |
| 143 | + out_file_path = ( |
| 144 | + f"{args.annotation_output}/{model_name}_tuned_torch.mlir" |
| 145 | + ) |
| 146 | + else: |
| 147 | + out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir" |
| 148 | + with open(out_file_path, "w") as f: |
121 | 149 | f.write(str(tuned_model)) |
| 150 | + f.close() |
| 151 | + return tuned_model, out_file_path |
| 152 | + |
| 153 | + |
def sd_model_annotation(mlir_model, model_name):
    """Run the tuning-annotation pipeline for the given model.

    On vulkan, unet gets Winograd annotation followed by tuned lowering
    configs, while vae gets Winograd annotation only; every other
    device/model combination gets the tuned lowering configs alone.

    Returns:
        (tuned_model, output_path): the annotated module and its saved path.
    """
    on_vulkan = device == "vulkan"
    if on_vulkan and args.annotation_model == "unet":
        winograd_config_dir = load_winograd_configs()
        winograd_model, model_path = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
        lowering_config_dir = load_lower_configs()
        tuned_model, output_path = annotate_with_lower_configs(
            model_path, lowering_config_dir, model_name, True
        )
    elif on_vulkan and args.annotation_model == "vae":
        winograd_config_dir = load_winograd_configs()
        tuned_model, output_path = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
    else:
        lowering_config_dir = load_lower_configs()
        tuned_model, output_path = annotate_with_lower_configs(
            mlir_model, lowering_config_dir, model_name, False
        )
    print(f"Saved the annotated mlir in {output_path}.")
    return tuned_model, output_path
| 179 | + |
| 180 | + |
if __name__ == "__main__":
    mlir_model, model_name = load_model_from_tank()
    # On cuda the pipeline reads the MLIR dump from disk instead of the
    # in-memory module returned by the downloader.
    if device == "cuda":
        mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
    sd_model_annotation(mlir_model, model_name)
0 commit comments