
Commit fb5468a

Mishig Davaadorj and patil-suraj authored
Add init_weights method to FlaxMixin (#513)
* Add `init_weights` method to `FlaxMixin`
* Rename `random_state` -> `shape_state`
* Use `PRNGKey(0)` for `jax.eval_shape`
* Do not allow mismatched sizes
* Update src/diffusers/modeling_flax_utils.py
* Add docstring

Co-authored-by: Suraj Patil <[email protected]>
1 parent d144c46 commit fb5468a

File tree

1 file changed (+70, -5 lines)

src/diffusers/modeling_flax_utils.py

Lines changed: 70 additions & 5 deletions
@@ -20,7 +20,7 @@
 import jax
 import jax.numpy as jnp
 import msgpack.exceptions
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import hf_hub_download
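
As an aside on the new import: `unfreeze` converts Flax's immutable `FrozenDict` into a plain, mutable nested `dict`, which is why the loading code below unfreezes the shape tree before flattening it. A standalone sketch, not part of the commit:

from flax.core.frozen_dict import freeze, unfreeze

# freeze/unfreeze round-trip between FrozenDict and a plain dict
frozen = freeze({"dense": {"kernel": [1.0, 2.0]}})
params = unfreeze(frozen)          # {'dense': {'kernel': [1.0, 2.0]}}
params["dense"]["bias"] = [0.0]    # mutation is allowed after unfreeze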
@@ -183,6 +183,9 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
         ```"""
         return self._cast_floating_to(params, jnp.float16, mask)
 
+    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
+        raise NotImplementedError(f"init_weights method has to be implemented for {self}")
+
     @classmethod
     def from_pretrained(
         cls,
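
The `init_weights` hook added here is deliberately abstract; each Flax model subclass supplies its own initializer. A hypothetical sketch of an override (the `TinyModel` class, its `nn.Dense` module, and the input shape are illustrative assumptions, not part of this commit):

from typing import Dict

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyModel:
    # Hypothetical model illustrating the expected init_weights signature.
    def __init__(self):
        self.module = nn.Dense(features=4)

    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
        # Initialize parameters from a dummy input; a real model would use
        # its expected sample shape here.
        sample = jnp.zeros((1, 8), dtype=jnp.float32)
        return self.module.init(rng, sample)["params"]


params = TinyModel().init_weights(jax.random.PRNGKey(0))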
@@ -227,10 +230,6 @@ def from_pretrained(
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
-            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
-                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
-                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
-                checkpoint with 3 labels).
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -394,6 +393,72 @@ def from_pretrained(
         # flatten dicts
         state = flatten_dict(state)
 
+        params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
+        required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        shape_state = flatten_dict(unfreeze(params_shape_tree))
+
+        missing_keys = required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - required_params
+
+        if missing_keys:
+            logger.warning(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                "Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
+        # Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in shape_state and state[key].shape != shape_state[key].shape:
+                raise ValueError(
+                    f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                    f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
+                )
+
+        # remove unexpected keys so they are not saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
         # dictionary of key: dtypes for the model params
         param_dtypes = jax.tree_map(lambda x: x.dtype, state)
         # extract keys of parameters not in jnp.float32
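
To see what `jax.eval_shape` contributes above: it traces `init_weights` abstractly and returns a pytree of `jax.ShapeDtypeStruct` leaves, so the required parameter names and shapes are known without allocating a single weight. A standalone sketch; the toy `init_weights` below is an assumption for illustration:

import jax
import jax.numpy as jnp


def init_weights(rng):
    # Toy initializer standing in for a model's real init_weights.
    return {"dense": {"kernel": jax.random.normal(rng, (8, 4)),
                      "bias": jnp.zeros((4,))}}


# No arrays are materialized; only shapes and dtypes are computed.
shape_tree = jax.eval_shape(init_weights, rng=jax.random.PRNGKey(0))
# {'dense': {'bias': ShapeDtypeStruct(shape=(4,), dtype=float32),
#            'kernel': ShapeDtypeStruct(shape=(8, 4), dtype=float32)}}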

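The missing/unexpected-key bookkeeping rests on `flatten_dict`, which maps a nested parameter tree to `{tuple_key: leaf}` so the comparison reduces to plain set differences. A small illustration with made-up placeholder leaves:

from flax.traverse_util import flatten_dict

checkpoint = flatten_dict({"dense": {"kernel": 0, "bias": 0}})
required = set(flatten_dict({"dense": {"kernel": 0}, "head": {"kernel": 0}}).keys())

missing_keys = required - set(checkpoint.keys())       # {('head', 'kernel')}
unexpected_keys = set(checkpoint.keys()) - required    # {('dense', 'bias')}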