Skip to content

Commit aca5c6d

Browse files
authored
[WIP] Load text_model.embeddings.position_ids outsude state_dict (#3829)
## What type of PR is this? (check all applicable) - [ ] Refactor - [ ] Feature - [x] Bug Fix - [ ] Optimization - [ ] Documentation Update ## Description In transformers 4.31.0 `text_model.embeddings.position_ids` no longer part of state_dict. Fix untested as can't run right now but should be correct. Also need to check how transformers 4.30.2 works with this fix. ## Related Tickets & Documents huggingface/transformers@8e5d161#diff-7f53db5caa73a4cbeb0dca3b396e3d52f30f025b8c48d4daf51eb7abb6e2b949R191 https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer ## QA Instructions, Screenshots, Recordings ``` File "C:\Users\artis\Documents\invokeai\.venv\lib\site-packages\invokeai\backend\model_management\convert_ckpt_to_diffusers.py", line 844, in convert_ldm_clip_checkpoint text_model.load_state_dict(text_model_dict) File "C:\Users\artis\Documents\invokeai\.venv\lib\site-packages\torch\nn\modules\module.py", line 2041, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for CLIPTextModel: Unexpected key(s) in state_dict: "text_model.embeddings.position_ids". ```
2 parents 23f0c70 + f932047 commit aca5c6d

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

invokeai/app/services/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def emit_model_load_completed(
141141
model_type=model_type,
142142
submodel=submodel,
143143
hash=model_info.hash,
144-
location=model_info.location,
144+
location=str(model_info.location),
145145
precision=str(model_info.precision),
146146
),
147147
)

invokeai/backend/model_management/convert_ckpt_to_diffusers.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import warnings
2222
from pathlib import Path
2323
from typing import Union
24+
from packaging import version
2425

2526
import torch
2627
from safetensors.torch import load_file
@@ -63,6 +64,7 @@
6364
StableDiffusionSafetyChecker,
6465
)
6566
from diffusers.utils import is_safetensors_available
67+
import transformers
6668
from transformers import (
6769
AutoFeatureExtractor,
6870
BertTokenizerFast,
@@ -841,7 +843,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
841843
key
842844
]
843845

844-
text_model.load_state_dict(text_model_dict)
846+
# transformers 4.31.0 and higher - this key no longer in state dict
847+
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
848+
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
849+
text_model.load_state_dict(text_model_dict)
850+
if position_ids is not None:
851+
text_model.text_model.embeddings.position_ids.copy_(position_ids)
852+
853+
# transformers 4.30.2 and lower - position_ids is part of state_dict
854+
else:
855+
text_model.load_state_dict(text_model_dict)
845856

846857
return text_model
847858

@@ -947,7 +958,16 @@ def convert_open_clip_checkpoint(checkpoint):
947958

948959
text_model_dict[new_key] = checkpoint[key]
949960

950-
text_model.load_state_dict(text_model_dict)
961+
# transformers 4.31.0 and higher - this key no longer in state dict
962+
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
963+
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
964+
text_model.load_state_dict(text_model_dict)
965+
if position_ids is not None:
966+
text_model.text_model.embeddings.position_ids.copy_(position_ids)
967+
968+
# transformers 4.30.2 and lower - position_ids is part of state_dict
969+
else:
970+
text_model.load_state_dict(text_model_dict)
951971

952972
return text_model
953973

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ dependencies = [
8383
"torchvision>=0.14.1",
8484
"torchmetrics==0.11.4",
8585
"torchsde==0.2.5",
86-
"transformers==4.30.2",
86+
"transformers~=4.31.0",
8787
"uvicorn[standard]==0.21.1",
8888
"windows-curses; sys_platform=='win32'",
8989
]

0 commit comments

Comments
 (0)