Skip to content

Commit 8b3706f

Browse files
authored
Add Anything v3 and AnalogDiffusion variants of SD (huggingface#685)
* base support for anythingv3 * add analogdiffusion * Update readme * keep max len 77 till support for 64 is added for variants * lint fix
1 parent 0d51738 commit 8b3706f

File tree

5 files changed

+256
-111
lines changed

5 files changed

+256
-111
lines changed

shark/examples/shark_inference/stable_diffusion/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,15 @@ unzip ~/.local/shark_tank/<your unet>/inputs.npz
4242

4343
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
4444
```
45+
46+
## Using other supported Stable Diffusion variants with SHARK:
47+
48+
Currently we support the following fine-tuned versions of Stable Diffusion:
49+
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
50+
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
51+
52+
Use the flag `--variant=` to specify which model to use.
53+
54+
```shell
55+
python .\shark\examples\shark_inference\stable_diffusion\main.py --variant=anythingv3 --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
56+
```

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def end_profiling(device):
5050
neg_prompt = args.negative_prompts
5151
height = 512 # default height of Stable Diffusion
5252
width = 512 # default width of Stable Diffusion
53-
if args.version == "v2.1":
53+
if args.version == "v2.1" and args.variant == "stablediffusion":
5454
height = 768
5555
width = 768
5656

@@ -71,9 +71,9 @@ def end_profiling(device):
7171
sys.exit("prompts and negative prompts must be of same length")
7272

7373
set_iree_runtime_flags()
74+
clip = get_clip()
7475
unet = get_unet()
7576
vae = get_vae()
76-
clip = get_clip()
7777
if args.dump_isa:
7878
dump_isas(args.dispatch_benchmarks_dir)
7979

shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 78 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212

1313
# clip has 2 variants of max length 77 or 64.
1414
model_clip_max_length = 64 if args.max_length == 64 else 77
15+
if args.variant != "stablediffusion":
16+
model_clip_max_length = 77
17+
18+
model_variant = {
19+
"stablediffusion": "SD",
20+
"anythingv3": "Linaqruf/anything-v3.0",
21+
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
22+
"openjourney": "prompthero/openjourney",
23+
"analogdiffusion": "wavymulder/Analog-Diffusion",
24+
}
1525

1626
model_input = {
1727
"v2.1": {
@@ -47,18 +57,32 @@
4757
}
4858

4959
# revision param for from_pretrained defaults to "main" => fp32
50-
model_revision = "fp16" if args.precision == "fp16" else "main"
60+
model_revision = {
61+
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
62+
"anythingv3": "diffusers",
63+
"analogdiffusion": "main",
64+
}
5165

5266

5367
def get_clip_mlir(model_name="clip_text", extra_args=[]):
5468

5569
text_encoder = CLIPTextModel.from_pretrained(
5670
"openai/clip-vit-large-patch14"
5771
)
58-
if args.version != "v1.4":
72+
if args.variant == "stablediffusion":
73+
if args.version != "v1.4":
74+
text_encoder = CLIPTextModel.from_pretrained(
75+
model_config[args.version], subfolder="text_encoder"
76+
)
77+
78+
elif args.variant in ["anythingv3", "analogdiffusion"]:
5979
text_encoder = CLIPTextModel.from_pretrained(
60-
model_config[args.version], subfolder="text_encoder"
80+
model_variant[args.variant],
81+
subfolder="text_encoder",
82+
revision=model_revision[args.variant],
6183
)
84+
else:
85+
raise (f"{args.variant} not yet added")
6286

6387
class CLIPText(torch.nn.Module):
6488
def __init__(self):
@@ -83,9 +107,11 @@ class VaeModel(torch.nn.Module):
83107
def __init__(self):
84108
super().__init__()
85109
self.vae = AutoencoderKL.from_pretrained(
86-
model_config[args.version],
110+
model_config[args.version]
111+
if args.variant == "stablediffusion"
112+
else model_variant[args.variant],
87113
subfolder="vae",
88-
revision=model_revision,
114+
revision=model_revision[args.variant],
89115
)
90116

91117
def forward(self, input):
@@ -96,16 +122,27 @@ def forward(self, input):
96122
return x.round()
97123

98124
vae = VaeModel()
99-
if args.precision == "fp16":
100-
vae = vae.half().cuda()
101-
inputs = tuple(
102-
[
103-
inputs.half().cuda()
104-
for inputs in model_input[args.version]["vae"]
105-
]
106-
)
125+
if args.variant == "stablediffusion":
126+
if args.precision == "fp16":
127+
vae = vae.half().cuda()
128+
inputs = tuple(
129+
[
130+
inputs.half().cuda()
131+
for inputs in model_input[args.version]["vae"]
132+
]
133+
)
134+
else:
135+
inputs = model_input[args.version]["vae"]
136+
elif args.variant in ["anythingv3", "analogdiffusion"]:
137+
if args.precision == "fp16":
138+
vae = vae.half().cuda()
139+
inputs = tuple(
140+
[inputs.half().cuda() for inputs in model_input["v1.4"]["vae"]]
141+
)
142+
else:
143+
inputs = model_input["v1.4"]["vae"]
107144
else:
108-
inputs = model_input[args.version]["vae"]
145+
raise (f"{args.variant} not yet added")
109146

110147
shark_vae = compile_through_fx(
111148
vae,
@@ -121,9 +158,11 @@ class UnetModel(torch.nn.Module):
121158
def __init__(self):
122159
super().__init__()
123160
self.unet = UNet2DConditionModel.from_pretrained(
124-
model_config[args.version],
161+
model_config[args.version]
162+
if args.variant == "stablediffusion"
163+
else model_variant[args.variant],
125164
subfolder="unet",
126-
revision=model_revision,
165+
revision=model_revision[args.variant],
127166
)
128167
self.in_channels = self.unet.in_channels
129168
self.train(False)
@@ -141,16 +180,30 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
141180
return noise_pred
142181

143182
unet = UnetModel()
144-
if args.precision == "fp16":
145-
unet = unet.half().cuda()
146-
inputs = tuple(
147-
[
148-
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
149-
for inputs in model_input[args.version]["unet"]
150-
]
151-
)
183+
if args.variant == "stablediffusion":
184+
if args.precision == "fp16":
185+
unet = unet.half().cuda()
186+
inputs = tuple(
187+
[
188+
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
189+
for inputs in model_input[args.version]["unet"]
190+
]
191+
)
192+
else:
193+
inputs = model_input[args.version]["unet"]
194+
elif args.variant in ["anythingv3", "analogdiffusion"]:
195+
if args.precision == "fp16":
196+
unet = unet.half().cuda()
197+
inputs = tuple(
198+
[
199+
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
200+
for inputs in model_input["v1.4"]["unet"]
201+
]
202+
)
203+
else:
204+
inputs = model_input["v1.4"]["unet"]
152205
else:
153-
inputs = model_input[args.version]["unet"]
206+
raise (f"{args.variant} is not yet added")
154207
shark_unet = compile_through_fx(
155208
unet,
156209
inputs,

0 commit comments

Comments
 (0)