Skip to content

Commit 039af39

Browse files
nirda7 and Eran Geva
authored and committed
[SW-194200] Save scale file only with new scales
Change-Id: I14a4ef94d188b13c2fbf4ea77d2b42cb5bd6d952
1 parent 4f8b257 commit 039af39

File tree

1 file changed

+3
-1
lines changed
  • neural_compressor/torch/algorithms/fp8_quant/_core

1 file changed

+3
-1
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/scale.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def get_config(
103103
)
104104
scales = convert_scales_to_tensors_dict(scales_obj, scales_file_format, params["hp_dtype"])
105105
model_dict = dict(model.named_modules())
106+
save_file = False
106107
for mname in mod_list:
107108
mod = model_dict[mname]
108109
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
@@ -123,6 +124,7 @@ def get_config(
123124
scales_obj[mname] = ModuleConfig(
124125
**format_functions_rec((torch.Tensor, scales_file_format))(scales[mname].__dict__)
125126
)
127+
save_file = True
126128

127129
logger.debug(
128130
"Preparing quantization functions for layer %s layer_type=%s",
@@ -138,7 +140,7 @@ def get_config(
138140
params,
139141
)
140142
qconfig[mname] = mod_extra_config
141-
if scales_file is not None:
143+
if save_file and scales_file is not None:
142144
save_scales(model, scales_obj, scales_file_format, scales_file + ".npz")
143145
save_scales(model, scales_obj, scales_file_format, scales_file + ".json")
144146
return qconfig

0 commit comments

Comments
 (0)