194 changes: 99 additions & 95 deletions ldm/modules/textual_inversion_manager.py
@@ -1,9 +1,9 @@
import os
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
@@ -71,21 +71,6 @@ def load_textual_inversion(

if str(ckpt_path).endswith(".DS_Store"):
return

try:
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
)
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return

embedding_info = self._parse_embedding(str(ckpt_path))

if embedding_info is None:
@@ -96,7 +81,7 @@ def load_textual_inversion(
!= embedding_info["token_dim"]
):
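            # e.g. a 768-dim SD-1.x embedding loaded against a 1024-dim SD-2.x text encoder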
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return

@@ -309,92 +294,72 @@ def _get_or_create_token_id_and_assign_embedding(

return token_id

def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split(".")[-1]
if file_type == "pt":
return self._parse_embedding_pt(embedding_file)
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
else:
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
    def _parse_embedding(self, embedding_file: str) -> dict:
        suffix = Path(embedding_file).suffix
        try:
            if suffix in [".pt", ".ckpt", ".bin"]:
scan_result = scan_file_path(embedding_file)
if scan_result.infected_files == 1:
print(
f" ** Security Issues Found in Model: {scan_result.issues_count}"
)
print(" ** For your safety, InvokeAI will not load this embed.")
return
                ckpt = torch.load(embedding_file, map_location="cpu")
else:
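                # safetensors files carry raw tensors only, with no pickled code,
                # so they can be loaded without a malware scan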
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
return None

def _parse_embedding_pt(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}

# Check if valid embedding file
if "string_to_token" and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try:
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
os.path.splitext(embedding_file)[0]
)

# Check num of embeddings and warn user only the first will be used
embedding_info["num_of_embeddings"] = len(
embedding_ckpt["string_to_token"]
)
if embedding_info["num_of_embeddings"] > 1:
print(">> More than 1 embedding found. Will use the first one")

embedding = list(embedding_ckpt["string_to_param"].values())[0]
except (AttributeError, KeyError):
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)

embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]

try:
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
except AttributeError:
print(">> No Training Details Found. Passing ...")

# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys()) == 1:
embedding_info = self._parse_embedding_bin(embedding_file)


# try to figure out what kind of embedding file it is and parse accordingly
keys = list(ckpt.keys())
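        # key layouts observed so far:
        #   v1: string_to_token, string_to_param, name, step  (classic TI checkpoints)
        #   v2: string_to_token and string_to_param only
        #   v3: a single 'emb_params' tensor  (e.g. easynegative.safetensors)
        #   v4: {token: tensor}  (usually .bin files from HuggingFace diffusers)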
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt

elif all(x in keys for x in ['string_to_token','string_to_param']):
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt

elif 'emb_params' in keys:
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors

else:
print(">> Invalid embedding format")
embedding_info = None
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file

return embedding_info
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
basename = Path(file_path).stem
print(f' | Loading v1 embedding file: {basename}')

def _parse_embedding_bin(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}

if list(embedding_ckpt.keys()) == 0:
print(">> Invalid concepts file")
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
embedding_info["name"] = (
token
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info[
"num_vectors_per_token"
] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]

embedding_info["name"] = embedding_ckpt["name"]

# Check num of embeddings and warn user only the first will be used
embedding_info["num_of_embeddings"] = len(
embedding_ckpt["string_to_token"]
)
if embedding_info["num_of_embeddings"] > 1:
print(" | More than 1 embedding found. Will use the first one")
embedding = list(embedding_ckpt["string_to_param"].values())[0]
embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
return embedding_info

def _handle_broken_pt_variants(
self, embedding_ckpt: dict, embedding_file: str
    def _parse_embedding_v2(
self, embedding_ckpt: dict, file_path: str
) -> dict:
"""
This handles the broken .pt file variants. We only know of one at present.
This handles embedding .pt file variant #2.
"""
basename = Path(file_path).stem
print(f' | Loading v2 embedding file: {basename}')
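        # unlike v1, these files lack the 'name'/'step' training metadata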
embedding_info = {}
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
Expand All @@ -403,7 +368,7 @@ def _handle_broken_pt_variants(
embedding_info["name"] = (
token
if token != "*"
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
else f"<{basename}>"
)
embedding_info["embedding"] = embedding_ckpt[
"string_to_param"
@@ -413,7 +378,46 @@
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
else:
print(">> Invalid embedding format")
print(f" ** {basename}: Unrecognized embedding format")
embedding_info = None

return embedding_info

def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
"""
        Parse 'version 3' of the textual inversion embedding files,
        which carry a single 'emb_params' tensor (e.g. easynegative.safetensors).
"""
basename = Path(file_path).stem
print(f' | Loading v3 embedding file: {basename}')
embedding_info = {}
embedding_info["name"] = f'<{basename}>'
embedding_info["num_of_embeddings"] = 1
embedding = embedding_ckpt['emb_params']
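        # 'emb_params' is a 2-D tensor shaped [num_vectors_per_token, token_dim]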
embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]
return embedding_info

def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
"""
Parse 'version 4' of the textual inversion embedding files. This one
is usually associated with .bin files trained by HuggingFace diffusers.
"""
basename = Path(filepath).stem
        short_path = Path(filepath).parents[0].name + '/' + Path(filepath).name

print(f' | Loading v4 embedding file: {short_path}')
embedding_info = {}
        if len(embedding_ckpt.keys()) == 0:
print(f" ** Invalid embeddings file: {short_path}")
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
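                # diffusers concept files map one trigger token to a single embedding tensor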
embedding_info["name"] = (
token
or f"<{basename}>"
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info["num_vectors_per_token"] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
return embedding_info