
Commit 38550b0

Fix transformers_trainer save_model logic (#580)
The change in #546 breaks some conditions for saving models in the HF training example. This PR fixes the issue by reverting the checkpoint-saving logic to its previous behavior and updating the config as a separate step. Related pipeline failure: https://gitlab-master.nvidia.com/omniml/modelopt/-/jobs/233869405

## What does this PR do?

**Type of change:** Bug fix

**Overview:** Restore the previous `save_model` checkpoint-saving flow and move the `config.json` dtype update out of the FSDP branch so it runs for the final checkpoint save regardless of which save path is taken.

Signed-off-by: Fridah-nv <[email protected]>
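For context, the updated `save_model` log message tells users to call `mto.enable_huggingface_checkpointing()` before reloading a saved checkpoint. A minimal restore sketch follows; the checkpoint path and the `AutoModelForCausalLM` model class are illustrative placeholders, not taken from this PR:

```python
import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# Register ModelOpt's HuggingFace checkpointing hooks before loading,
# as instructed by the save_model log message in this change.
mto.enable_huggingface_checkpointing()

# Hypothetical output directory previously written by Trainer.save_model().
model = AutoModelForCausalLM.from_pretrained("path/to/quantized-checkpoint")
```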
1 parent 1aaa77d commit 38550b0

File tree: 1 file changed (+20, -22 lines)


modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 20 additions & 22 deletions
```diff
@@ -262,32 +262,30 @@ def train(self, *args, **kwargs):
 
     def save_model(self, *args, **kwargs):
         """Save the quantized model."""
-        if not self.is_in_train:
-            if (
-                self.is_fsdp_enabled
-                and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
-            ):
+        if (
+            (not self.is_in_train)
+            and self.is_fsdp_enabled
+            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
+        ):
+            print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
+            original_type = self.accelerator.state.fsdp_plugin.state_dict_type
+            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+            outputs = super().save_model(*args, **kwargs)
+            if torch.distributed.is_initialized():
+                torch.distributed.barrier()
+            if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
                 print_rank_0(
-                    "Setting state_dict_type to FULL_STATE_DICT for final checkpoint save."
+                    "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
+                    "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
                 )
-                original_type = self.accelerator.state.fsdp_plugin.state_dict_type
-                self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-                outputs = super().save_model(*args, **kwargs)
-                if torch.distributed.is_initialized():
-                    torch.distributed.barrier()
-                if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
-                    print_rank_0(
-                        "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
-                        "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
-                    )
-                self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
-                if self.args.should_save:
-                    out_dir = args[0]
-                    # FSDP may upcast parameter dtype to float32 during mixed-precision training,
-                    # we convert it back to original dtype by updating `torch-dtype` in `config.json`
-                    self._update_config_json_dtype(out_dir, str(self._original_dtype).split(".")[1])
+            self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
         else:
             outputs = super().save_model(*args, **kwargs)
+        if (not self.is_in_train) and self.args.should_save:
+            out_dir = args[0]
+            # FSDP may upcast parameter dtype to float32 during mixed-precision training,
+            # we convert it back to original dtype by updating `torch-dtype` in `config.json`
+            self._update_config_json_dtype(out_dir, str(self._original_dtype).split(".")[1])
         return outputs
 
     def _load_best_model(self, *args, **kwargs):
```
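The relocated `_update_config_json_dtype` call is the piece that now runs outside the FSDP branch. Its body is not part of this diff; the helper below is a hypothetical stand-in that only illustrates the idea of rewriting the dtype recorded in `config.json`, not the repository's actual implementation:

```python
import json
import os

import torch


def update_config_json_dtype(out_dir: str, dtype_name: str) -> None:
    """Hypothetical stand-in for the trainer's private _update_config_json_dtype helper."""
    config_path = os.path.join(out_dir, "config.json")
    if not os.path.isfile(config_path):
        return
    with open(config_path) as f:
        config = json.load(f)
    # Record the original training dtype instead of the float32 FSDP may have upcast to.
    config["torch_dtype"] = dtype_name
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)


# str(torch.bfloat16).split(".")[1] == "bfloat16", matching the dtype-string
# conversion used at the call site in the diff. The path is a placeholder.
update_config_json_dtype("path/to/quantized-checkpoint", str(torch.bfloat16).split(".")[1])
```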
