|
20 | 20 | import jax |
21 | 21 | import jax.numpy as jnp |
22 | 22 | import msgpack.exceptions |
23 | | -from flax.core.frozen_dict import FrozenDict |
| 23 | +from flax.core.frozen_dict import FrozenDict, unfreeze |
24 | 24 | from flax.serialization import from_bytes, to_bytes |
25 | 25 | from flax.traverse_util import flatten_dict, unflatten_dict |
26 | 26 | from huggingface_hub import hf_hub_download |
@@ -183,6 +183,9 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): |
183 | 183 | ```""" |
184 | 184 | return self._cast_floating_to(params, jnp.float16, mask) |
185 | 185 |
|
| 186 | + def init_weights(self, rng: jax.random.PRNGKey) -> Dict: |
| 187 | + raise NotImplementedError(f"init_weights method has to be implemented for {self}") |
| 188 | + |
186 | 189 | @classmethod |
187 | 190 | def from_pretrained( |
188 | 191 | cls, |
@@ -227,10 +230,6 @@ def from_pretrained( |
227 | 230 | cache_dir (`Union[str, os.PathLike]`, *optional*): |
228 | 231 | Path to a directory in which a downloaded pretrained model configuration should be cached if the |
229 | 232 | standard cache should not be used. |
230 | | - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): |
231 | | - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size |
232 | | - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a |
233 | | - checkpoint with 3 labels). |
234 | 233 | force_download (`bool`, *optional*, defaults to `False`): |
235 | 234 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the |
236 | 235 | cached versions if they exist. |
@@ -394,6 +393,72 @@ def from_pretrained( |
394 | 393 | # flatten dicts |
395 | 394 | state = flatten_dict(state) |
396 | 395 |
|
| 396 | + params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) |
| 397 | + required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) |
| 398 | + |
| 399 | + shape_state = flatten_dict(unfreeze(params_shape_tree)) |
| 400 | + |
| 401 | + missing_keys = required_params - set(state.keys()) |
| 402 | + unexpected_keys = set(state.keys()) - required_params |
| 403 | + |
| 404 | + if missing_keys: |
| 405 | + logger.warning( |
| 406 | + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " |
| 407 | + "Make sure to call model.init_weights to initialize the missing weights." |
| 408 | + ) |
| 409 | + cls._missing_keys = missing_keys |
| 410 | + |
| 411 | + # Mismatched keys would contain tuples (key, shape1, shape2) for checkpoint weights whose shape does not |
| 412 | + # match the model; a shape mismatch currently raises an error below instead of being collected. |
| 413 | + mismatched_keys = [] |
| 414 | + for key in state.keys(): |
| 415 | + if key in shape_state and state[key].shape != shape_state[key].shape: |
| 416 | + raise ValueError( |
| 417 | + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " |
| 418 | + f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " |
| 419 | + ) |
| 420 | + |
| 421 | + # remove unexpected keys so they are not saved again with the model |
| 422 | + for unexpected_key in unexpected_keys: |
| 423 | + del state[unexpected_key] |
| 424 | + |
| 425 | + if len(unexpected_keys) > 0: |
| 426 | + logger.warning( |
| 427 | + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" |
| 428 | + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" |
| 429 | + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" |
| 430 | + " with another architecture." |
| 431 | + ) |
| 432 | + else: |
| 433 | + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") |
| 434 | + |
| 435 | + if len(missing_keys) > 0: |
| 436 | + logger.warning( |
| 437 | + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" |
| 438 | + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" |
| 439 | + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." |
| 440 | + ) |
| 441 | + elif len(mismatched_keys) == 0: |
| 442 | + logger.info( |
| 443 | + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" |
| 444 | + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" |
| 445 | + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" |
| 446 | + " training." |
| 447 | + ) |
| 448 | + if len(mismatched_keys) > 0: |
| 449 | + mismatched_warning = "\n".join( |
| 450 | + [ |
| 451 | + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the instantiated model" |
| 452 | + for key, shape1, shape2 in mismatched_keys |
| 453 | + ] |
| 454 | + ) |
| 455 | + logger.warning( |
| 456 | + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" |
| 457 | + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" |
| 458 | + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" |
| 459 | + " to use it for predictions and inference." |
| 460 | + ) |
| 461 | + |
397 | 462 | # dictionary of key: dtypes for the model params |
398 | 463 | param_dtypes = jax.tree_map(lambda x: x.dtype, state) |
399 | 464 | # extract keys of parameters not in jnp.float32 |
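
The new abstract init_weights hook is what makes the jax.eval_shape call in the last hunk possible: the base class can recover the expected parameter tree (shapes and dtypes only) without allocating real weights. Below is a minimal sketch of how a subclass might implement it, assuming a toy flax.linen module; ToyModule, ToyPreTrainedModel and the dummy input shape are illustrative, not part of this diff.

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x)


class ToyPreTrainedModel:
    def __init__(self):
        self.module = ToyModule()

    def init_weights(self, rng: jax.random.PRNGKey):
        # Initialize the flax module from dummy inputs; a real model would build
        # dummy input_ids/attention_mask matching its expected call signature.
        dummy_inputs = jnp.zeros((1, 8), dtype=jnp.float32)
        return self.module.init({"params": rng}, dummy_inputs)["params"]


model = ToyPreTrainedModel()
params = model.init_weights(jax.random.PRNGKey(0))
# jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) yields the same
# tree with ShapeDtypeStruct leaves, i.e. no device memory is allocated.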
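The missing/unexpected/mismatched bookkeeping added to from_pretrained reduces to set operations over flattened parameter dictionaries. Here is a self-contained sketch of that logic; the nested dicts, key names and shapes are invented for illustration.

import jax.numpy as jnp
from flax.traverse_util import flatten_dict

# Stand-in for the shape tree returned by jax.eval_shape(model.init_weights, ...).
shape_state = flatten_dict({"dense": {"kernel": jnp.zeros((8, 4)), "bias": jnp.zeros((4,))}})
# Stand-in for the weights deserialized from the checkpoint.
state = flatten_dict({"dense": {"kernel": jnp.zeros((8, 4))}, "old_head": {"kernel": jnp.zeros((4, 2))}})

required_params = set(shape_state.keys())
missing_keys = required_params - set(state.keys())     # {('dense', 'bias')}: warned about, must be initialized
unexpected_keys = set(state.keys()) - required_params   # {('old_head', 'kernel')}: warned about and dropped

# Unexpected weights are deleted so they are not saved again with the model.
for key in unexpected_keys:
    del state[key]

# A checkpoint weight whose shape disagrees with the model raises, mirroring the loop above.
for key in state.keys():
    if key in shape_state and state[key].shape != shape_state[key].shape:
        raise ValueError(f"shape mismatch for {key}")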
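From the caller's perspective nothing changes: from_pretrained still accepts the documented arguments such as cache_dir and force_download, and the new checks only surface as warnings (or an error on a shape mismatch). A hypothetical usage, where the model class and checkpoint id are just examples:

from transformers import FlaxBertModel

# Loads the Flax checkpoint from the Hub (or a local path) and runs the
# missing/unexpected key checks shown in the diff above.
model = FlaxBertModel.from_pretrained(
    "bert-base-uncased",
    cache_dir="./hf_cache",   # optional custom cache location
    force_download=False,     # set True to re-download even if a cached copy exists
)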