-import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Union
 
+import safetensors.torch
 import torch
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -71,21 +71,6 @@ def load_textual_inversion(
 
         if str(ckpt_path).endswith(".DS_Store"):
             return
-
-        try:
-            scan_result = scan_file_path(str(ckpt_path))
-            if scan_result.infected_files == 1:
-                print(
-                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
-                )
-                print("### For your safety, InvokeAI will not load this embed.")
-                return
-        except Exception:
-            print(
-                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
-            )
-            return
-
         embedding_info = self._parse_embedding(str(ckpt_path))
 
         if embedding_info is None:
@@ -96,7 +81,7 @@ def load_textual_inversion(
             != embedding_info["token_dim"]
         ):
             print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
+                f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
             )
             return
 
@@ -309,92 +294,72 @@ def _get_or_create_token_id_and_assign_embedding(
 
         return token_id
 
-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
-        else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
+    def _parse_embedding(self, embedding_file: str) -> dict:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt", ".ckpt", ".bin"]:
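+                # pickle-based checkpoint formats can execute arbitrary code on load, so scan them first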
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files == 1:
+                    print(
+                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print(" ** For your safety, InvokeAI will not load this embed.")
+                    return
+                ckpt = torch.load(embedding_file, map_location="cpu")
+            else:
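+                # .safetensors files hold raw tensors only, so no pickle scan is needed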
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
             return None
-
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ['string_to_token', 'string_to_param', 'name', 'step']):
+            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt
+
+        elif all(x in keys for x in ['string_to_token', 'string_to_param']):
+            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt
+
+        elif 'emb_params' in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors
+
         else:
-            print(">> Invalid embedding format")
-            embedding_info = None
+            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file
 
-        return embedding_info
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
+        basename = Path(file_path).stem
+        print(f' | Loading v1 embedding file: {basename}')
 
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
         embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-
+        embedding_info["name"] = embedding_ckpt["name"]
+
+        # Check num of embeddings and warn user only the first will be used
+        embedding_info["num_of_embeddings"] = len(
+            embedding_ckpt["string_to_token"]
+        )
+        if embedding_info["num_of_embeddings"] > 1:
+            print(" | More than 1 embedding found. Will use the first one")
+        embedding = list(embedding_ckpt["string_to_param"].values())[0]
+        embedding_info["embedding"] = embedding
+        embedding_info["num_vectors_per_token"] = embedding.size()[0]
+        embedding_info["token_dim"] = embedding.size()[1]
+        embedding_info["trained_steps"] = embedding_ckpt["step"]
+        embedding_info["trained_model_name"] = embedding_ckpt[
+            "sd_checkpoint_name"
+        ]
+        embedding_info["trained_model_checksum"] = embedding_ckpt[
+            "sd_checkpoint"
+        ]
         return embedding_info
 
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
+    def _parse_embedding_v2(
+        self, embedding_ckpt: dict, file_path: str
     ) -> dict:
         """
-        This handles the broken .pt file variants. We only know of one at present.
+        This handles embedding .pt file variant #2.
         """
+        basename = Path(file_path).stem
+        print(f' | Loading v2 embedding file: {basename}')
         embedding_info = {}
         if isinstance(
             list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
@@ -403,7 +368,7 @@ def _handle_broken_pt_variants(
             embedding_info["name"] = (
                 token
                 if token != "*"
-                else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
+                else f"<{basename}>"
             )
             embedding_info["embedding"] = embedding_ckpt[
                 "string_to_param"
@@ -413,7 +378,46 @@ def _handle_broken_pt_variants(
             ].shape[0]
             embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
         else:
-            print(">> Invalid embedding format")
+            print(f" ** {basename}: Unrecognized embedding format")
             embedding_info = None
 
         return embedding_info
+
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
+        """
+        Parse 'version 3' of the .pt textual inversion embedding files.
+        """
+        basename = Path(file_path).stem
+        print(f' | Loading v3 embedding file: {basename}')
+        embedding_info = {}
+        embedding_info["name"] = f'<{basename}>'
+        embedding_info["num_of_embeddings"] = 1
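+        # 'emb_params' holds a single tensor shaped (num_vectors, token_dim)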
+        embedding = embedding_ckpt['emb_params']
+        embedding_info["embedding"] = embedding
+        embedding_info["num_vectors_per_token"] = embedding.size()[0]
+        embedding_info["token_dim"] = embedding.size()[1]
+        return embedding_info
+
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
+        """
+        Parse 'version 4' of the textual inversion embedding files. This one
+        is usually associated with .bin files trained by HuggingFace diffusers.
+        """
+        basename = Path(filepath).stem
+        short_path = Path(filepath).parents[0].name + '/' + Path(filepath).name
+
+        print(f' | Loading v4 embedding file: {short_path}')
+        embedding_info = {}
+        if len(embedding_ckpt.keys()) == 0:
+            print(f" ** Invalid embeddings file: {short_path}")
+            embedding_info = None
+        else:
+            for token in list(embedding_ckpt.keys()):
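+                # each pass overwrites embedding_info, so only the last token's data survives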
+                embedding_info["name"] = (
+                    token
+                    or f"<{basename}>"
+                )
+                embedding_info["embedding"] = embedding_ckpt[token]
+                embedding_info["num_vectors_per_token"] = 1  # All Concepts seem to default to 1
+                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
+        return embedding_info
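
For reference, here is a minimal standalone sketch of the key-based dispatch that the new _parse_embedding performs. The classify_embedding helper and its command-line wrapper are hypothetical illustrations only, not part of this commit; real use should go through _parse_embedding so the picklescan check runs first.

import sys
from pathlib import Path

import safetensors.torch
import torch


def classify_embedding(embedding_file: str) -> str:
    """Hypothetical helper: report which parser _parse_embedding would choose."""
    suffix = Path(embedding_file).suffix
    if suffix in [".pt", ".ckpt", ".bin"]:
        # pickle-based formats; the real code scans these with picklescan before loading
        ckpt = torch.load(embedding_file, map_location="cpu")
    else:
        ckpt = safetensors.torch.load_file(embedding_file)
    keys = list(ckpt.keys())
    if all(x in keys for x in ["string_to_token", "string_to_param", "name", "step"]):
        return "v1"  # e.g. rem_rezero.pt
    elif all(x in keys for x in ["string_to_token", "string_to_param"]):
        return "v2"  # e.g. midj-strong.pt
    elif "emb_params" in keys:
        return "v3"  # e.g. easynegative.safetensors
    return "v4"  # usually a diffusers-style .bin


if __name__ == "__main__":
    print(classify_embedding(sys.argv[1]))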