From b9d9c7b6557041404ba47362c00573122dfab041 Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Fri, 14 Jun 2024 19:44:12 +0000 Subject: [PATCH 1/2] working trl parser with config correctly overrides yaml config with command line arguments adds return_remaining_strings when return_remaining_strings is False, raises error if yaml contains extra args that are not in the dataclasses simpler and cleaner than previous yaml parsing and merging addresses #1733 --- trl/commands/cli_utils.py | 176 +++++++++++--------------------------- 1 file changed, 49 insertions(+), 127 deletions(-) diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index ce08d7b6309..bf997d3fcaf 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -13,13 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import os import sys -from copy import deepcopy -from dataclasses import asdict, dataclass, field, fields -from typing import Any, List +from argparse import Namespace +from dataclasses import dataclass, field import yaml from transformers import HfArgumentParser @@ -29,87 +27,23 @@ class YamlConfigParser: - def __init__(self, config_path: str = None, dataclasses: List[Any] = None): - self.config = None + def parse_and_set_env(self, config_path): + with open(config_path) as yaml_file: + config = yaml.safe_load(yaml_file) - if config_path is not None: - with open(config_path) as yaml_file: - self.config = yaml.safe_load(yaml_file) - else: - self.config = {} - - if dataclasses is None: - dataclasses = [] - - # We create a dummy training args to compare the values before / after - # __post_init__ - # Here we import `TrainingArguments` from the local level to not - # break TRL lazy imports. - from transformers import TrainingArguments - - self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args") - - self.parse_and_set_env() - self.merge_dataclasses(dataclasses) - - def parse_and_set_env(self): - if "env" in self.config: - env_vars = self.config["env"] + if "env" in config: + env_vars = config.pop("env") if isinstance(env_vars, dict): for key, value in env_vars.items(): os.environ[key] = str(value) else: raise ValueError("`env` field should be a dict in the YAML file.") - def merge_dataclasses(self, dataclasses): - from transformers import TrainingArguments - - dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses] - - if len(self.config) > 0: - for i, dataclass in enumerate(dataclasses): - is_hf_training_args = False - - for data_class_field in fields(dataclass): - # Get the field here - field_name = data_class_field.name - field_value = getattr(dataclass, field_name) - - if not isinstance(dataclass, TrainingArguments) or not hasattr( - self._dummy_training_args, field_name - ): - default_value = data_class_field.default - else: - default_value = ( - getattr(self._dummy_training_args, field_name) - if field_name != "output_dir" - else field_name - ) - is_hf_training_args = True - - default_value_changed = field_value != default_value - - if field_value is not None or field_name in self.config: - if field_name in self.config: - # In case the field value is not different from default, overwrite it - if not default_value_changed: - value_to_replace = self.config[field_name] - setattr(dataclasses_copy[i], field_name, value_to_replace) - # Otherwise do nothing - - # Re-init `TrainingArguments` or derived class to handle all post-processing correctly - if is_hf_training_args: - ArgCls = type(dataclass) - init_signature = list(inspect.signature(ArgCls.__init__).parameters) - dict_dataclass = asdict(dataclasses_copy[i]) - new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature} - dataclasses_copy[i] = ArgCls(**new_dict_dataclass) - - return dataclasses_copy - - def to_string(self): + return config + + def to_string(self, config): final_string = """""" - for key, value in self.config.items(): + for key, value in config.items(): if isinstance(value, (dict, list)): if len(value) != 0: value = str(value) @@ -249,7 +183,7 @@ class ChatArguments: use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) -class TrlParser(HfArgumentParser): +class TRLParser(HfArgumentParser): def __init__(self, parsers): """ The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config @@ -261,8 +195,7 @@ def __init__(self, parsers): List of parsers. """ super().__init__(parsers) - - self.config_parser = None + self.yaml_parser = YamlConfigParser() def post_process_dataclasses(self, dataclasses): # Apply additional post-processing in case some arguments needs a special @@ -274,10 +207,7 @@ def post_process_dataclasses(self, dataclasses): if dataclass_obj.__class__.__name__ == "TrainingArguments": training_args = dataclass_obj training_args_index = i - elif dataclass_obj.__class__.__name__ in ( - "SFTScriptArguments", - "DPOScriptArguments", - ): + elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"): trl_args = dataclass_obj else: ... @@ -290,52 +220,44 @@ def post_process_dataclasses(self, dataclasses): return dataclasses - def parse_args_and_config(self): - # Hack to force-replace the `output_dir` from the YAML file if one did not passed - # output_dir in the command line + def parse_args_and_config(self, return_remaining_strings=False): + yaml_config = None if "--config" in sys.argv: - config_index = sys.argv.index("--config") + 1 - config_path = sys.argv[config_index] + config_index = sys.argv.index("--config") - self.config_parser = YamlConfigParser(config_path) - output_dir = self.config_parser.config.get("output_dir") + _ = sys.argv.pop(config_index) # --config + config_path = sys.argv.pop(config_index) # path to config + yaml_config = self.yaml_parser.parse_and_set_env(config_path) - if output_dir is not None: - if "--output_dir" in sys.argv: - output_dir_index = sys.argv.index("--output_dir") - passed_output_dir = sys.argv[output_dir_index + 1] - self.config_parser.config["output_dir"] = passed_output_dir - else: - sys.argv.extend(["--output_dir", output_dir]) - - dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True) - - if len(dataclasses[-1]) > 0: - # It is expected that `config` is in that list but not ignored - # let's simply remove them - list_ignored = dataclasses[-1] - if "--config" in list_ignored: - config_index = list_ignored.index("--config") + 1 - config_path = list_ignored[config_index] - - list_ignored.remove(config_path) - list_ignored.remove("--config") + self.set_defaults_with_config(**yaml_config) - if len(list_ignored) > 0: - logger.warning( - f"Detected extra arguments that are going to be ignored: {list_ignored} - make sure to double check what you are doing" - ) + outputs = self.parse_args_into_dataclasses(return_remaining_strings=return_remaining_strings) - # Pop the last element which should be the remaining strings - dataclasses = self.update_dataclasses_with_config(dataclasses[:-1]) - return dataclasses - - def update_dataclasses_with_config(self, dataclasses): - for parser_dataclass in dataclasses: - if hasattr(parser_dataclass, "config") and self.config_parser is None: - self.config_parser = YamlConfigParser(parser_dataclass.config) + if yaml_config is None: + return outputs - if self.config_parser is not None: - dataclasses = self.config_parser.merge_dataclasses(dataclasses) - dataclasses = self.post_process_dataclasses(dataclasses) - return dataclasses + if return_remaining_strings: + # if we have extra yaml config and command line strings + # outputs[-1] is remaining command line strings + # outputs[-2] is remaining yaml config as Namespace + # combine them into remaining strings object + remaining_strings = outputs[-1] + [f"{key}: {value}" for key, value in vars(outputs[-2]).items()] + return outputs[:-2], remaining_strings + else: + # outputs[-1] is either remaining yaml config as Namespace or parsed config as Dataclass + if isinstance(outputs[-1], Namespace): + remaining_args = vars(outputs[-1]) + raise ValueError(f"Some specified config arguments are not used by the TRLParser: {remaining_args}") + + return outputs + + def set_defaults_with_config(self, **kwargs): + """Defaults we're setting with config allow us to change to required = False""" + self._defaults.update(kwargs) + + # if these defaults match any existing arguments, replace + # the previous default on the object with the new one + for action in self._actions: + if action.dest in kwargs: + action.default = kwargs[action.dest] + action.required = False From e5a24b04bb364be1f49d2403ae9f03adfdb8d156 Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Fri, 14 Jun 2024 19:52:39 +0000 Subject: [PATCH 2/2] lowercase trlparser --- trl/commands/cli_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index bf997d3fcaf..7cd9258fcb1 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -183,7 +183,7 @@ class ChatArguments: use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) -class TRLParser(HfArgumentParser): +class TrlParser(HfArgumentParser): def __init__(self, parsers): """ The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config @@ -247,7 +247,7 @@ def parse_args_and_config(self, return_remaining_strings=False): # outputs[-1] is either remaining yaml config as Namespace or parsed config as Dataclass if isinstance(outputs[-1], Namespace): remaining_args = vars(outputs[-1]) - raise ValueError(f"Some specified config arguments are not used by the TRLParser: {remaining_args}") + raise ValueError(f"Some specified config arguments are not used by the TrlParser: {remaining_args}") return outputs