Skip to content

Commit 2e807e9

Browse files
committed
Fix a bug where compile_config gets deleted
1 parent 078ff68 commit 2e807e9

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

src/transformers/generation/configuration_utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ def save_pretrained(
755755

756756
output_config_file = os.path.join(save_directory, config_file_name)
757757

758-
self.to_json_file(output_config_file, use_diff=True)
758+
self.to_json_file(output_config_file, use_diff=True, pop_keys=["compile_config"])
759759
logger.info(f"Configuration saved in {output_config_file}")
760760

761761
if push_to_hub:
@@ -1022,16 +1022,16 @@ def to_dict(self) -> dict[str, Any]:
10221022
del output["_commit_hash"]
10231023
if "_original_object_hash" in output:
10241024
del output["_original_object_hash"]
1025-
if "compile_config" in output:
1026-
del output["compile_config"]
10271025

10281026
# Transformers version when serializing this file
10291027
output["transformers_version"] = __version__
10301028

10311029
self.dict_dtype_to_str(output)
10321030
return output
10331031

1034-
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
1032+
def to_json_string(
1033+
self, use_diff: bool = True, ignore_metadata: bool = False, pop_keys: list[str] | None = None
1034+
) -> str:
10351035
"""
10361036
Serializes this instance to a JSON string.
10371037
@@ -1041,6 +1041,8 @@ def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -
10411041
is serialized to JSON string.
10421042
ignore_metadata (`bool`, *optional*, defaults to `False`):
10431043
Whether to ignore the metadata fields present in the instance
1044+
pop_keys (`list[str]`, *optional*):
1045+
Keys to pop from the config dictionary before serializing
10441046
10451047
Returns:
10461048
`str`: String containing all the attributes that make up this configuration instance in JSON format.
@@ -1050,6 +1052,10 @@ def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -
10501052
else:
10511053
config_dict = self.to_dict()
10521054

1055+
if pop_keys is not None:
1056+
for key in pop_keys:
1057+
config_dict.pop(key, None)
1058+
10531059
if ignore_metadata:
10541060
for metadata_field in METADATA_FIELDS:
10551061
config_dict.pop(metadata_field, None)
@@ -1075,7 +1081,9 @@ def convert_dataclass_to_dict(obj):
10751081

10761082
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
10771083

1078-
def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):
1084+
def to_json_file(
1085+
self, json_file_path: str | os.PathLike, use_diff: bool = True, pop_keys: list[str] | None = None
1086+
) -> None:
10791087
"""
10801088
Save this instance to a JSON file.
10811089
@@ -1085,9 +1093,11 @@ def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True)
10851093
use_diff (`bool`, *optional*, defaults to `True`):
10861094
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
10871095
is serialized to JSON file.
1096+
pop_keys (`list[str]`, *optional*):
1097+
Keys to pop from the config dictionary before serializing
10881098
"""
10891099
with open(json_file_path, "w", encoding="utf-8") as writer:
1090-
writer.write(self.to_json_string(use_diff=use_diff))
1100+
writer.write(self.to_json_string(use_diff=use_diff, pop_keys=pop_keys))
10911101

10921102
@classmethod
10931103
def from_model_config(cls, model_config: PreTrainedConfig | dict) -> "GenerationConfig":

0 commit comments

Comments
 (0)