1212
1313# clip has 2 variants of max length 77 or 64.
1414model_clip_max_length = 64 if args .max_length == 64 else 77
15+ if args .variant != "stablediffusion" :
16+ model_clip_max_length = 77
17+
18+ model_variant = {
19+ "stablediffusion" : "SD" ,
20+ "anythingv3" : "Linaqruf/anything-v3.0" ,
21+ "dreamlike" : "dreamlike-art/dreamlike-diffusion-1.0" ,
22+ "openjourney" : "prompthero/openjourney" ,
23+ "analogdiffusion" : "wavymulder/Analog-Diffusion" ,
24+ }
1525
1626model_input = {
1727 "v2.1" : {
4757}
4858
4959# revision param for from_pretrained defaults to "main" => fp32
50- model_revision = "fp16" if args .precision == "fp16" else "main"
60+ model_revision = {
61+ "stablediffusion" : "fp16" if args .precision == "fp16" else "main" ,
62+ "anythingv3" : "diffusers" ,
63+ "analogdiffusion" : "main" ,
64+ }
5165
5266
5367def get_clip_mlir (model_name = "clip_text" , extra_args = []):
5468
5569 text_encoder = CLIPTextModel .from_pretrained (
5670 "openai/clip-vit-large-patch14"
5771 )
58- if args .version != "v1.4" :
72+ if args .variant == "stablediffusion" :
73+ if args .version != "v1.4" :
74+ text_encoder = CLIPTextModel .from_pretrained (
75+ model_config [args .version ], subfolder = "text_encoder"
76+ )
77+
78+ elif args .variant in ["anythingv3" , "analogdiffusion" ]:
5979 text_encoder = CLIPTextModel .from_pretrained (
60- model_config [args .version ], subfolder = "text_encoder"
80+ model_variant [args .variant ],
81+ subfolder = "text_encoder" ,
82+ revision = model_revision [args .variant ],
6183 )
84+ else :
85+ raise (f"{ args .variant } not yet added" )
6286
6387 class CLIPText (torch .nn .Module ):
6488 def __init__ (self ):
@@ -83,9 +107,11 @@ class VaeModel(torch.nn.Module):
83107 def __init__ (self ):
84108 super ().__init__ ()
85109 self .vae = AutoencoderKL .from_pretrained (
86- model_config [args .version ],
110+ model_config [args .version ]
111+ if args .variant == "stablediffusion"
112+ else model_variant [args .variant ],
87113 subfolder = "vae" ,
88- revision = model_revision ,
114+ revision = model_revision [ args . variant ] ,
89115 )
90116
91117 def forward (self , input ):
@@ -96,16 +122,27 @@ def forward(self, input):
96122 return x .round ()
97123
98124 vae = VaeModel ()
99- if args .precision == "fp16" :
100- vae = vae .half ().cuda ()
101- inputs = tuple (
102- [
103- inputs .half ().cuda ()
104- for inputs in model_input [args .version ]["vae" ]
105- ]
106- )
125+ if args .variant == "stablediffusion" :
126+ if args .precision == "fp16" :
127+ vae = vae .half ().cuda ()
128+ inputs = tuple (
129+ [
130+ inputs .half ().cuda ()
131+ for inputs in model_input [args .version ]["vae" ]
132+ ]
133+ )
134+ else :
135+ inputs = model_input [args .version ]["vae" ]
136+ elif args .variant in ["anythingv3" , "analogdiffusion" ]:
137+ if args .precision == "fp16" :
138+ vae = vae .half ().cuda ()
139+ inputs = tuple (
140+ [inputs .half ().cuda () for inputs in model_input ["v1.4" ]["vae" ]]
141+ )
142+ else :
143+ inputs = model_input ["v1.4" ]["vae" ]
107144 else :
108- inputs = model_input [ args .version ][ "vae" ]
145+ raise ( f" { args .variant } not yet added" )
109146
110147 shark_vae = compile_through_fx (
111148 vae ,
@@ -121,9 +158,11 @@ class UnetModel(torch.nn.Module):
121158 def __init__ (self ):
122159 super ().__init__ ()
123160 self .unet = UNet2DConditionModel .from_pretrained (
124- model_config [args .version ],
161+ model_config [args .version ]
162+ if args .variant == "stablediffusion"
163+ else model_variant [args .variant ],
125164 subfolder = "unet" ,
126- revision = model_revision ,
165+ revision = model_revision [ args . variant ] ,
127166 )
128167 self .in_channels = self .unet .in_channels
129168 self .train (False )
@@ -141,16 +180,30 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
141180 return noise_pred
142181
143182 unet = UnetModel ()
144- if args .precision == "fp16" :
145- unet = unet .half ().cuda ()
146- inputs = tuple (
147- [
148- inputs .half ().cuda () if len (inputs .shape ) != 0 else inputs
149- for inputs in model_input [args .version ]["unet" ]
150- ]
151- )
183+ if args .variant == "stablediffusion" :
184+ if args .precision == "fp16" :
185+ unet = unet .half ().cuda ()
186+ inputs = tuple (
187+ [
188+ inputs .half ().cuda () if len (inputs .shape ) != 0 else inputs
189+ for inputs in model_input [args .version ]["unet" ]
190+ ]
191+ )
192+ else :
193+ inputs = model_input [args .version ]["unet" ]
194+ elif args .variant in ["anythingv3" , "analogdiffusion" ]:
195+ if args .precision == "fp16" :
196+ unet = unet .half ().cuda ()
197+ inputs = tuple (
198+ [
199+ inputs .half ().cuda () if len (inputs .shape ) != 0 else inputs
200+ for inputs in model_input ["v1.4" ]["unet" ]
201+ ]
202+ )
203+ else :
204+ inputs = model_input ["v1.4" ]["unet" ]
152205 else :
153- inputs = model_input [ args .version ][ "unet" ]
206+ raise ( f" { args .variant } is not yet added" )
154207 shark_unet = compile_through_fx (
155208 unet ,
156209 inputs ,
0 commit comments