Skip to content

Commit fd59fc1

Browse files
younesbelkadaAndrechang
authored andcommitted
[bnb] Fix bnb config json serialization (#24137)
* fix bnb config json serialization * forward contrib credits from discussions --------- Co-authored-by: Andrechang <[email protected]>
1 parent a272e41 commit fd59fc1

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

src/transformers/configuration_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,13 @@ def to_diff_dict(self) -> Dict[str, Any]:
784784
):
785785
serializable_config_dict[key] = value
786786

787+
if hasattr(self, "quantization_config"):
788+
serializable_config_dict["quantization_config"] = (
789+
self.quantization_config.to_dict()
790+
if not isinstance(self.quantization_config, dict)
791+
else self.quantization_config
792+
)
793+
787794
self.dict_torch_dtype_to_str(serializable_config_dict)
788795

789796
return serializable_config_dict

tests/bitsandbytes/test_4bit.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,19 @@ def tearDown(self):
111111
gc.collect()
112112
torch.cuda.empty_cache()
113113

114+
def test_quantization_config_json_serialization(self):
115+
r"""
116+
A simple test to check if the quantization config is correctly serialized and deserialized
117+
"""
118+
config = self.model_4bit.config
119+
120+
self.assertTrue(hasattr(config, "quantization_config"))
121+
122+
_ = config.to_dict()
123+
_ = config.to_diff_dict()
124+
125+
_ = config.to_json_string()
126+
114127
def test_memory_footprint(self):
115128
r"""
116129
A simple test to check if the model conversion has been done correctly by checking on the

tests/bitsandbytes/test_mixed_int8.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def tearDown(self):
118118
gc.collect()
119119
torch.cuda.empty_cache()
120120

121+
def test_quantization_config_json_serialization(self):
122+
r"""
123+
A simple test to check if the quantization config is correctly serialized and deserialized
124+
"""
125+
config = self.model_8bit.config
126+
127+
self.assertTrue(hasattr(config, "quantization_config"))
128+
129+
_ = config.to_dict()
130+
_ = config.to_diff_dict()
131+
132+
_ = config.to_json_string()
133+
121134
def test_memory_footprint(self):
122135
r"""
123136
A simple test to check if the model conversion has been done correctly by checking on the

0 commit comments

Comments
 (0)