@@ -6,6 +6,7 @@
 import coremltools as ct

 import logging
+import json

 logging.basicConfig()
 logger = logging.getLogger(__name__)
@@ -21,14 +22,47 @@ class CoreMLModel:
     """ Wrapper for running CoreML models using coremltools
     """

-    def __init__(self, model_path, compute_unit):
-        assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
+    def __init__(self, model_path, compute_unit, sources='packages'):

         logger.info(f"Loading {model_path}")

         start = time.time()
-        self.model = ct.models.MLModel(
-            model_path, compute_units=ct.ComputeUnit[compute_unit])
+        if sources == 'packages':
+            assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
+
+            self.model = ct.models.MLModel(
+                model_path, compute_units=ct.ComputeUnit[compute_unit])
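+            # Core ML spec ArrayDataType enum values mapped to numpy dtypes
+            # (FLOAT16 = 65552, FLOAT32 = 65568, INT32 = 131104)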
+            DTYPE_MAP = {
+                65552: np.float16,
+                65568: np.float32,
+                131104: np.int32,
+            }
+            self.expected_inputs = {
+                input_tensor.name: {
+                    "shape": tuple(input_tensor.type.multiArrayType.shape),
+                    "dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
+                }
+                for input_tensor in self.model._spec.description.input
+            }
+        elif sources == 'compiled':
+            assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")
+
+            self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])
+
+            # Grab expected inputs from metadata.json
+            with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
+                config = json.load(f)[0]
+
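+            # shapes in metadata.json are serialized as strings, hence the eval below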
+            self.expected_inputs = {
+                input_tensor['name']: {
+                    "shape": tuple(eval(input_tensor['shape'])),
+                    "dtype": np.dtype(input_tensor['dataType'].lower()),
+                }
+                for input_tensor in config['inputSchema']
+            }
+        else:
+            raise ValueError(f'Expected `packages` or `compiled` for sources, received {sources}')
+
         load_time = time.time() - start
         logger.info(f"Done. Took {load_time:.1f} seconds.")

@@ -38,21 +72,6 @@ def __init__(self, model_path, compute_unit):
                 "The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
             )

-
-        DTYPE_MAP = {
-            65552: np.float16,
-            65568: np.float32,
-            131104: np.int32,
-        }
-
-        self.expected_inputs = {
-            input_tensor.name: {
-                "shape": tuple(input_tensor.type.multiArrayType.shape),
-                "dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
-            }
-            for input_tensor in self.model._spec.description.input
-        }
-
     def _verify_inputs(self, **kwargs):
         for k, v in kwargs.items():
             if k in self.expected_inputs:
@@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):
                         f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
                     )
             else:
-                raise ValueError("Received unexpected input kwarg: {k}")
+                raise ValueError(f"Received unexpected input kwarg: {k}")

     def __call__(self, **kwargs):
         self._verify_inputs(**kwargs)
@@ -82,21 +101,77 @@ def __call__(self, **kwargs):
 LOAD_TIME_INFO_MSG_TRIGGER = 10  # seconds


-def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
-                    compute_unit):
-    """ Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
+def get_resource_type(resources_dir: str) -> str:
+    """
+    Detect resource type based on filepath extensions.
+    Returns:
+        `packages`: for .mlpackage resources
+        `compiled`: for .mlmodelc resources
     """
-    logger.info(f"Loading {submodule_name} mlpackage")
+    directories = [f for f in os.listdir(resources_dir) if os.path.isdir(os.path.join(resources_dir, f))]

-    fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
-        "/", "_")
-    mlpackage_path = os.path.join(mlpackages_dir, fname)
+    # consider directories ending with extension
+    extensions = set([os.path.splitext(e)[1] for e in directories if os.path.splitext(e)[1]])

-    if not os.path.exists(mlpackage_path):
-        raise FileNotFoundError(
-            f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
+    # if exactly one extension is present we may be able to infer the sources type
+    if len(extensions) == 1:
+        extension = extensions.pop()
+    else:
+        raise ValueError(f'Multiple file extensions found at {resources_dir}. '
+                         f'Cannot infer resource type from contents.')
+
+    if extension == '.mlpackage':
+        sources = 'packages'
+    elif extension == '.mlmodelc':
+        sources = 'compiled'
+    else:
+        raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}')
+
+    return sources
+
+
+def _load_mlpackage(submodule_name,
+                    mlpackages_dir,
+                    model_version,
+                    compute_unit,
+                    sources=None):
+    """
+    Load Core ML models from disk, either as .mlpackage bundles (as exported by
+    torch2coreml.py) or as precompiled .mlmodelc bundles.
+    """
+
+    # if sources not provided, attempt to infer `packages` or `compiled` from the
+    # resources directory
+    if sources is None:
+        sources = get_resource_type(mlpackages_dir)
+
+    if sources == 'packages':
+        logger.info(f"Loading {submodule_name} mlpackage")
+        fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
+            "/", "_")
+        mlpackage_path = os.path.join(mlpackages_dir, fname)
+
+        if not os.path.exists(mlpackage_path):
+            raise FileNotFoundError(
+                f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
+
+    elif sources == 'compiled':
+        logger.info(f"Loading {submodule_name} mlmodelc")
+
+        # FixMe: Submodule names and compiled resources names differ. Can change if names match in the future.
+        submodule_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]
+        compiled_names = ['TextEncoder', 'TextEncoder2', 'Unet', 'VAEDecoder', 'VAEEncoder']
+        name_map = dict(zip(submodule_names, compiled_names))
+
+        cname = name_map[submodule_name] + '.mlmodelc'
+        mlpackage_path = os.path.join(mlpackages_dir, cname)
+
+        if not os.path.exists(mlpackage_path):
+            raise FileNotFoundError(
+                f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
+
+    return CoreMLModel(mlpackage_path, compute_unit, sources=sources)

-    return CoreMLModel(mlpackage_path, compute_unit)

 def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
     """ Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
@@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

     return CoreMLModel(mlpackage_path, compute_unit)

+
 def get_available_compute_units():
     return tuple(cu for cu in ct.ComputeUnit._member_names_)
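
A minimal usage sketch, assuming a resources directory that holds precompiled .mlmodelc bundles (TextEncoder.mlmodelc, Unet.mlmodelc, ...); the directory path, model version, and compute unit below are placeholders:

    text_encoder = _load_mlpackage(
        "text_encoder",
        mlpackages_dir="/path/to/Resources",
        model_version="runwayml/stable-diffusion-v1-5",  # placeholder model version
        compute_unit="CPU_AND_NE",  # any member of ct.ComputeUnit
        sources=None,  # None lets get_resource_type() infer 'compiled' from the .mlmodelc extension
    )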