Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 48 additions & 126 deletions trl/commands/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import os
import sys
from copy import deepcopy
from dataclasses import asdict, dataclass, field, fields
from typing import Any, List
from argparse import Namespace
from dataclasses import dataclass, field

import yaml
from transformers import HfArgumentParser
Expand All @@ -29,87 +27,23 @@


class YamlConfigParser:
def __init__(self, config_path: str = None, dataclasses: List[Any] = None):
self.config = None
def parse_and_set_env(self, config_path):
with open(config_path) as yaml_file:
config = yaml.safe_load(yaml_file)

if config_path is not None:
with open(config_path) as yaml_file:
self.config = yaml.safe_load(yaml_file)
else:
self.config = {}

if dataclasses is None:
dataclasses = []

# We create a dummy training args to compare the values before / after
# __post_init__
# Here we import `TrainingArguments` from the local level to not
# break TRL lazy imports.
from transformers import TrainingArguments

self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args")

self.parse_and_set_env()
self.merge_dataclasses(dataclasses)

def parse_and_set_env(self):
if "env" in self.config:
env_vars = self.config["env"]
if "env" in config:
env_vars = config.pop("env")
if isinstance(env_vars, dict):
for key, value in env_vars.items():
os.environ[key] = str(value)
else:
raise ValueError("`env` field should be a dict in the YAML file.")

def merge_dataclasses(self, dataclasses):
from transformers import TrainingArguments

dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses]

if len(self.config) > 0:
for i, dataclass in enumerate(dataclasses):
is_hf_training_args = False

for data_class_field in fields(dataclass):
# Get the field here
field_name = data_class_field.name
field_value = getattr(dataclass, field_name)

if not isinstance(dataclass, TrainingArguments) or not hasattr(
self._dummy_training_args, field_name
):
default_value = data_class_field.default
else:
default_value = (
getattr(self._dummy_training_args, field_name)
if field_name != "output_dir"
else field_name
)
is_hf_training_args = True

default_value_changed = field_value != default_value

if field_value is not None or field_name in self.config:
if field_name in self.config:
# In case the field value is not different from default, overwrite it
if not default_value_changed:
value_to_replace = self.config[field_name]
setattr(dataclasses_copy[i], field_name, value_to_replace)
# Otherwise do nothing

# Re-init `TrainingArguments` or derived class to handle all post-processing correctly
if is_hf_training_args:
ArgCls = type(dataclass)
init_signature = list(inspect.signature(ArgCls.__init__).parameters)
dict_dataclass = asdict(dataclasses_copy[i])
new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
dataclasses_copy[i] = ArgCls(**new_dict_dataclass)

return dataclasses_copy

def to_string(self):
return config

def to_string(self, config):
final_string = """"""
for key, value in self.config.items():
for key, value in config.items():
if isinstance(value, (dict, list)):
if len(value) != 0:
value = str(value)
Expand Down Expand Up @@ -261,8 +195,7 @@ def __init__(self, parsers):
List of parsers.
"""
super().__init__(parsers)

self.config_parser = None
self.yaml_parser = YamlConfigParser()

def post_process_dataclasses(self, dataclasses):
# Apply additional post-processing in case some arguments needs a special
Expand All @@ -274,10 +207,7 @@ def post_process_dataclasses(self, dataclasses):
if dataclass_obj.__class__.__name__ == "TrainingArguments":
training_args = dataclass_obj
training_args_index = i
elif dataclass_obj.__class__.__name__ in (
"SFTScriptArguments",
"DPOScriptArguments",
):
elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"):
trl_args = dataclass_obj
else:
...
Expand All @@ -290,52 +220,44 @@ def post_process_dataclasses(self, dataclasses):

return dataclasses

def parse_args_and_config(self):
# Hack to force-replace the `output_dir` from the YAML file if one did not passed
# output_dir in the command line
def parse_args_and_config(self, return_remaining_strings=False):
yaml_config = None
if "--config" in sys.argv:
config_index = sys.argv.index("--config") + 1
config_path = sys.argv[config_index]
config_index = sys.argv.index("--config")

self.config_parser = YamlConfigParser(config_path)
output_dir = self.config_parser.config.get("output_dir")
_ = sys.argv.pop(config_index) # --config
config_path = sys.argv.pop(config_index) # path to config
yaml_config = self.yaml_parser.parse_and_set_env(config_path)

if output_dir is not None:
if "--output_dir" in sys.argv:
output_dir_index = sys.argv.index("--output_dir")
passed_output_dir = sys.argv[output_dir_index + 1]
self.config_parser.config["output_dir"] = passed_output_dir
else:
sys.argv.extend(["--output_dir", output_dir])

dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)

if len(dataclasses[-1]) > 0:
# It is expected that `config` is in that list but not ignored
# let's simply remove them
list_ignored = dataclasses[-1]
if "--config" in list_ignored:
config_index = list_ignored.index("--config") + 1
config_path = list_ignored[config_index]

list_ignored.remove(config_path)
list_ignored.remove("--config")
self.set_defaults_with_config(**yaml_config)

if len(list_ignored) > 0:
logger.warning(
f"Detected extra arguments that are going to be ignored: {list_ignored} - make sure to double check what you are doing"
)
outputs = self.parse_args_into_dataclasses(return_remaining_strings=return_remaining_strings)

# Pop the last element which should be the remaining strings
dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
return dataclasses

def update_dataclasses_with_config(self, dataclasses):
for parser_dataclass in dataclasses:
if hasattr(parser_dataclass, "config") and self.config_parser is None:
self.config_parser = YamlConfigParser(parser_dataclass.config)
if yaml_config is None:
return outputs

if self.config_parser is not None:
dataclasses = self.config_parser.merge_dataclasses(dataclasses)
dataclasses = self.post_process_dataclasses(dataclasses)
return dataclasses
if return_remaining_strings:
# if we have extra yaml config and command line strings
# outputs[-1] is remaining command line strings
# outputs[-2] is remaining yaml config as Namespace
# combine them into remaining strings object
remaining_strings = outputs[-1] + [f"{key}: {value}" for key, value in vars(outputs[-2]).items()]
return outputs[:-2], remaining_strings
else:
# outputs[-1] is either remaining yaml config as Namespace or parsed config as Dataclass
if isinstance(outputs[-1], Namespace):
remaining_args = vars(outputs[-1])
raise ValueError(f"Some specified config arguments are not used by the TrlParser: {remaining_args}")

return outputs

def set_defaults_with_config(self, **kwargs):
"""Defaults we're setting with config allow us to change to required = False"""
self._defaults.update(kwargs)

# if these defaults match any existing arguments, replace
# the previous default on the object with the new one
for action in self._actions:
if action.dest in kwargs:
action.default = kwargs[action.dest]
action.required = False