
Commit 5aade31

Pad conditionings using zeros and encoder_attention_mask (#3772)
2 parents 565299c + 7093e5d

2 files changed: +36, -5 lines


invokeai/app/invocations/compel.py

Lines changed: 1 addition & 4 deletions
@@ -100,7 +100,7 @@ def _lora_loader():
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=True,  # TODO:
+            truncate_long_prompts=False,
         )
 
         conjunction = Compel.parse_prompt_string(self.prompt)
@@ -112,9 +112,6 @@ def _lora_loader():
         c, options = compel.build_conditioning_tensor_for_prompt_object(
             prompt)
 
-        # TODO: long prompt support
-        # if not self.truncate_long_prompts:
-        #     [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
             tokens_count_including_eos_bos=get_max_token_count(
                 tokenizer, conjunction),
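
Setting truncate_long_prompts=False makes compel return embeddings that span as many 77-token CLIP chunks as the prompt needs, so the positive and negative conditioning tensors may no longer share a sequence length; the diffusion-side change below reconciles them at denoising time. A minimal sketch of the length mismatch, assuming the compel and transformers packages (the checkpoint id is illustrative):

from compel import Compel
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

compel = Compel(tokenizer=tokenizer, text_encoder=text_encoder, truncate_long_prompts=False)

c = compel.build_conditioning_tensor("a detailed oil painting of a ship at sea, " * 10)  # > 77 tokens
uc = compel.build_conditioning_tensor("")                                                # one 77-token chunk
print(c.shape, uc.shape)  # e.g. torch.Size([1, 154, 768]) vs torch.Size([1, 77, 768])

compel's pad_conditioning_tensors_to_same_length([c, uc]) — the approach sketched in the removed TODO — would equalize the lengths inside compel; this commit instead zero-pads at denoising time and masks the padding, as the next file shows.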

invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

Lines changed: 35 additions & 1 deletion
@@ -241,11 +241,45 @@ def do_latent_postprocessing(
 
     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         # fast batched path
+
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < target_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], target_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
         both_conditionings = torch.cat([unconditioning, conditioning])
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
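
The helper zero-pads the shorter conditioning along the token axis and builds a matching attention mask (ones over real tokens, zeros over padding); the two masks are concatenated along the batch dimension so they line up with the uncond/cond batch that goes through the model in one call. A standalone sketch of the same scheme on toy tensors (shapes are illustrative of SD 1.x conditionings):

import torch

def pad_conditioning(cond, target_len, encoder_attention_mask):
    # ones over the real token positions of this conditioning
    mask = torch.ones(cond.shape[:2], device=cond.device, dtype=cond.dtype)
    if cond.shape[1] < target_len:
        pad = target_len - cond.shape[1]
        # extend the mask with zeros and the embeddings with zero vectors
        mask = torch.cat(
            [mask, torch.zeros((cond.shape[0], pad), device=cond.device, dtype=cond.dtype)], dim=1)
        cond = torch.cat(
            [cond, torch.zeros((cond.shape[0], pad, cond.shape[2]), device=cond.device, dtype=cond.dtype)], dim=1)
    # stack masks along the batch axis, mirroring torch.cat([uncond, cond])
    if encoder_attention_mask is None:
        encoder_attention_mask = mask
    else:
        encoder_attention_mask = torch.cat([encoder_attention_mask, mask])
    return cond, encoder_attention_mask

uncond = torch.randn(1, 77, 768)  # e.g. an empty negative prompt: one CLIP chunk
cond = torch.randn(1, 154, 768)   # a long positive prompt: two CLIP chunks

max_len = max(uncond.shape[1], cond.shape[1])
mask = None
uncond, mask = pad_conditioning(uncond, max_len, mask)
cond, mask = pad_conditioning(cond, max_len, mask)

both = torch.cat([uncond, cond])  # one batched forward pass for CFG
print(both.shape, mask.shape)     # torch.Size([2, 154, 768]) torch.Size([2, 154])

In diffusers, the UNet's forward accepts an encoder_attention_mask argument that suppresses cross-attention to the masked (padded) key positions, which is how model_forward_callback is expected to consume the mask here.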
