Skip to content

Commit 0d196f9

Browse files
authored
Fix issue in maybe_convert_prompt (#3188)
When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens. Adding a space for the padding tokens fixes this.
1 parent 131312c commit 0d196f9

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
410410
replacement = token
411411
i = 1
412412
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
413-
replacement += f"{token}_{i}"
413+
replacement += f" {token}_{i}"
414414
i += 1
415415

416416
prompt = prompt.replace(token, replacement)

tests/pipelines/test_pipelines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def test_text_inversion_download(self):
541541
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
542542
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
543543
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
544-
assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2"
544+
assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***> <***>_1 <***>_2"
545545

546546
prompt = "hey <***>"
547547
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
@@ -569,7 +569,7 @@ def test_text_inversion_download(self):
569569
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
570570
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
571571
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
572-
assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2"
572+
assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****> <****>_1 <****>_2"
573573

574574
prompt = "hey <****>"
575575
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images

0 commit comments

Comments
 (0)