@@ -75,11 +75,69 @@ def __init__(
             scheduler=scheduler,
         )

+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
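+        # encodes the prompt (and, when doing classifier free guidance, an empty unconditional prompt) into CLIP text embeddings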
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
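+        # when doing classifier free guidance, also embed an unconditional (empty string) prompt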
+        if do_classifier_free_guidance:
+            uncond_tokens = [""] * batch_size
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
@@ -137,6 +195,12 @@ def __call__(

         batch_size = batch_size * num_images_per_prompt

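+        # classifier free guidance is only applied when guidance_scale > 1.0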
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(
+            prompt, num_images_per_prompt, do_classifier_free_guidance
+        )
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +209,6 @@ def __call__(
                 f" {type(callback_steps)}."
             )

-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it

         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +233,17 @@ def __call__(
         sample = latents

         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(latent_model_input, encoder_hidden_states=text_embeddings, timestep=t).sample
+
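+            # model_output is `log_p_x_0`, so the unconditional and conditional predictions are
+            # combined in log space and then renormalized with logsumexp over the class dimension
+            # to keep a valid log probability distribution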
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

             model_output = self.truncate(model_output, truncation_rate)
