Commit bb972b2

Add support for yet another TI embedding file format (2.3 version) (#3045)
- This variant, exemplified by "easynegative.safetensors", has a single 'emb_params' key containing a Tensor.
- Also refactored the code to make it easier to read.
- Handle both pickle and safetensors formats.
2 parents 41a8fde + a78ff86 commit bb972b2
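
The detection logic for the new variant boils down to a key check on the loaded state dict. Below is a minimal sketch of probing a file the same way the new `_parse_embedding` dispatches; the `probe_embedding` helper and the literal path are illustrative, not part of the commit:

import safetensors.torch
import torch

def probe_embedding(path: str) -> str:
    # Pickle-based formats (.pt/.ckpt/.bin) load via torch;
    # anything else is treated as safetensors, mirroring _parse_embedding.
    if path.endswith((".pt", ".ckpt", ".bin")):
        ckpt = torch.load(path, map_location="cpu")
    else:
        ckpt = safetensors.torch.load_file(path)
    keys = set(ckpt.keys())
    if {"string_to_token", "string_to_param", "name", "step"} <= keys:
        return "v1"  # e.g. rem_rezero.pt
    if {"string_to_token", "string_to_param"} <= keys:
        return "v2"  # e.g. midj-strong.pt
    if "emb_params" in keys:
        return "v3"  # e.g. easynegative.safetensors: one Tensor under 'emb_params'
    return "v4"  # usually a diffusers-style .bin with a single token key

# Assuming the file exists locally:
print(probe_embedding("easynegative.safetensors"))  # -> "v3"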

File tree

1 file changed: +99 −95 lines changed

ldm/modules/textual_inversion_manager.py
Lines changed: 99 additions & 95 deletions
@@ -1,9 +1,9 @@
-import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Union

+import safetensors.torch
 import torch
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -71,21 +71,6 @@ def load_textual_inversion(

         if str(ckpt_path).endswith(".DS_Store"):
             return
-
-        try:
-            scan_result = scan_file_path(str(ckpt_path))
-            if scan_result.infected_files == 1:
-                print(
-                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
-                )
-                print("### For your safety, InvokeAI will not load this embed.")
-                return
-        except Exception:
-            print(
-                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
-            )
-            return
-
         embedding_info = self._parse_embedding(str(ckpt_path))

         if embedding_info is None:
@@ -96,7 +81,7 @@ def load_textual_inversion(
             != embedding_info["token_dim"]
         ):
             print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
+                f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
             )
             return

@@ -309,92 +294,72 @@ def _get_or_create_token_id_and_assign_embedding(

         return token_id

-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
-        else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
+    def _parse_embedding(self, embedding_file: str) -> dict:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt", ".ckpt", ".bin"]:
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files == 1:
+                    print(
+                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print(" ** For your safety, InvokeAI will not load this embed.")
+                    return
+                ckpt = torch.load(embedding_file, map_location="cpu")
+            else:
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
             return None
-
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ['string_to_token', 'string_to_param', 'name', 'step']):
+            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt
+
+        elif all(x in keys for x in ['string_to_token', 'string_to_param']):
+            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt
+
+        elif 'emb_params' in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors
+
         else:
-            print(">> Invalid embedding format")
-            embedding_info = None
+            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file

-        return embedding_info
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
+        basename = Path(file_path).stem
+        print(f' | Loading v1 embedding file: {basename}')

-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
         embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-
+        embedding_info["name"] = embedding_ckpt["name"]
+
+        # Check num of embeddings and warn user only the first will be used
+        embedding_info["num_of_embeddings"] = len(
+            embedding_ckpt["string_to_token"]
+        )
+        if embedding_info["num_of_embeddings"] > 1:
+            print(" | More than 1 embedding found. Will use the first one")
+        embedding = list(embedding_ckpt["string_to_param"].values())[0]
+        embedding_info["embedding"] = embedding
+        embedding_info["num_vectors_per_token"] = embedding.size()[0]
+        embedding_info["token_dim"] = embedding.size()[1]
+        embedding_info["trained_steps"] = embedding_ckpt["step"]
+        embedding_info["trained_model_name"] = embedding_ckpt[
+            "sd_checkpoint_name"
+        ]
+        embedding_info["trained_model_checksum"] = embedding_ckpt[
+            "sd_checkpoint"
+        ]
         return embedding_info

-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
+    def _parse_embedding_v2(
+        self, embedding_ckpt: dict, file_path: str
     ) -> dict:
         """
-        This handles the broken .pt file variants. We only know of one at present.
+        This handles embedding .pt file variant #2.
         """
+        basename = Path(file_path).stem
+        print(f' | Loading v2 embedding file: {basename}')
         embedding_info = {}
         if isinstance(
             list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
@@ -403,7 +368,7 @@ def _handle_broken_pt_variants(
             embedding_info["name"] = (
                 token
                 if token != "*"
-                else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
+                else f"<{basename}>"
             )
             embedding_info["embedding"] = embedding_ckpt[
                 "string_to_param"
@@ -413,7 +378,46 @@ def _handle_broken_pt_variants(
             ].shape[0]
             embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
         else:
-            print(">> Invalid embedding format")
+            print(f" ** {basename}: Unrecognized embedding format")
             embedding_info = None

         return embedding_info
+
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
+        """
+        Parse 'version 3' of the textual inversion embedding files.
+        """
+        basename = Path(file_path).stem
+        print(f' | Loading v3 embedding file: {basename}')
+        embedding_info = {}
+        embedding_info["name"] = f'<{basename}>'
+        embedding_info["num_of_embeddings"] = 1
+        embedding = embedding_ckpt['emb_params']
+        embedding_info["embedding"] = embedding
+        embedding_info["num_vectors_per_token"] = embedding.size()[0]
+        embedding_info["token_dim"] = embedding.size()[1]
+        return embedding_info
+
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
+        """
+        Parse 'version 4' of the textual inversion embedding files. This one
+        is usually associated with .bin files trained by HuggingFace diffusers.
+        """
+        basename = Path(filepath).stem
+        short_path = Path(filepath).parents[0].name + '/' + Path(filepath).name
+
+        print(f' | Loading v4 embedding file: {short_path}')
+        embedding_info = {}
+        if list(embedding_ckpt.keys()) == 0:
+            print(f" ** Invalid embeddings file: {short_path}")
+            embedding_info = None
+        else:
+            for token in list(embedding_ckpt.keys()):
+                embedding_info["name"] = (
+                    token
+                    or f"<{basename}>"
+                )
+                embedding_info["embedding"] = embedding_ckpt[token]
+                embedding_info["num_vectors_per_token"] = 1  # All Concepts seem to default to 1
+                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
+        return embedding_info
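
All four parsers normalize to the same `embedding_info` dict (`name`, `embedding`, `num_vectors_per_token`, `token_dim`, plus training metadata where available), which is what keeps `load_textual_inversion` format-agnostic. A hedged usage sketch, assuming `manager` is an initialized TextualInversionManager; the path is illustrative and `_parse_embedding` is a private method shown here only to demonstrate the returned shape:

# Parse one embedding file and inspect the normalized result.
info = manager._parse_embedding("embeddings/easynegative.safetensors")
if info is not None:
    # token_dim must match the text encoder's embedding width
    # (manager.text_encoder.get_input_embeddings().weight.data[0].shape[0]),
    # or load_textual_inversion prints a notice and skips the file.
    print(info["name"], info["num_vectors_per_token"], info["token_dim"])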
