Skip to content

Commit 18ec3e9

Browse files
Fix torch_dtype handling in {DPO,SFT}Trainer when provided via CLI (huggingface#1807)
* Fix `torch_dtype` handling through CLI The `torch_dtype` is not properly handled when provided via the TRL CLI since it's provided initially as a string, but is then casted to `torch.dtype` before providing it to the `{DPO,SFT}Trainer`, which means that those trainers should handle the scenario where `torch_dtype` is a `torch.dtype` too. * Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py` * Forward contribution credits * Run `make precommit` --------- Co-authored-by: Tash Srivastava <[email protected]>
1 parent 5e0ca32 commit 18ec3e9

File tree

4 files changed

+144
-15
lines changed

4 files changed

+144
-15
lines changed

tests/test_dpo_trainer.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import tempfile
1516
import unittest
1617

18+
import pytest
1719
import torch
1820
from datasets import Dataset, features
1921
from parameterized import parameterized
@@ -840,6 +842,68 @@ def test_dpo_lora_force_use_ref(self):
840842
# train the model
841843
trainer.train()
842844

845+
def test_dpo_trainer_torch_dtype(self):
846+
# See https://github.com/huggingface/trl/issues/1751
847+
dummy_dataset = self._init_dummy_dataset()
848+
with tempfile.TemporaryDirectory() as tmp_dir:
849+
dpo_config = DPOConfig(
850+
output_dir=tmp_dir,
851+
per_device_train_batch_size=2,
852+
max_steps=1,
853+
model_init_kwargs={"torch_dtype": "float16"},
854+
ref_model_init_kwargs={"torch_dtype": "float16"},
855+
)
856+
857+
trainer = DPOTrainer(
858+
model=self.model_id,
859+
ref_model=self.model_id,
860+
tokenizer=self.tokenizer,
861+
args=dpo_config,
862+
train_dataset=dummy_dataset,
863+
)
864+
assert trainer.model.config.torch_dtype == torch.float16
865+
assert trainer.ref_model.config.torch_dtype == torch.float16
866+
867+
# Now test when `torch_dtype` is provided but is wrong to either the model or the ref_model
868+
with tempfile.TemporaryDirectory() as tmp_dir:
869+
dpo_config = DPOConfig(
870+
output_dir=tmp_dir,
871+
per_device_train_batch_size=2,
872+
max_steps=1,
873+
model_init_kwargs={"torch_dtype": -1},
874+
)
875+
876+
with pytest.raises(
877+
ValueError,
878+
match="Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
879+
):
880+
_ = DPOTrainer(
881+
model=self.model_id,
882+
tokenizer=self.tokenizer,
883+
args=dpo_config,
884+
train_dataset=dummy_dataset,
885+
)
886+
887+
with tempfile.TemporaryDirectory() as tmp_dir:
888+
dpo_config = DPOConfig(
889+
output_dir=tmp_dir,
890+
per_device_train_batch_size=2,
891+
max_steps=1,
892+
ref_model_init_kwargs={"torch_dtype": -1},
893+
)
894+
895+
with pytest.raises(
896+
ValueError,
897+
match="Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
898+
):
899+
_ = DPOTrainer(
900+
model=self.model_id,
901+
ref_model=self.model_id,
902+
tokenizer=self.tokenizer,
903+
args=dpo_config,
904+
train_dataset=dummy_dataset,
905+
)
906+
843907
def test_dpo_loss_alpha_div_f(self):
844908
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
845909
tokenizer = AutoTokenizer.from_pretrained(model_id)

tests/test_sft_trainer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,3 +1273,45 @@ def __call__(self, examples):
12731273
assert trainer.state.log_history[0]["eval_loss"] is not None
12741274

12751275
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
1276+
1277+
def test_sft_trainer_torch_dtype(self):
1278+
# See https://github.com/huggingface/trl/issues/1751
1279+
with tempfile.TemporaryDirectory() as tmp_dir:
1280+
training_args = SFTConfig(
1281+
output_dir=tmp_dir,
1282+
eval_strategy="steps",
1283+
max_steps=4,
1284+
eval_steps=2,
1285+
save_steps=2,
1286+
per_device_train_batch_size=2,
1287+
model_init_kwargs={"torch_dtype": torch.float16},
1288+
)
1289+
trainer = SFTTrainer(
1290+
model=self.model_id,
1291+
args=training_args,
1292+
train_dataset=self.train_dataset,
1293+
eval_dataset=self.eval_dataset,
1294+
)
1295+
assert trainer.model.config.torch_dtype == torch.float16
1296+
1297+
# Now test when `torch_dtype` is provided but is wrong
1298+
with tempfile.TemporaryDirectory() as tmp_dir:
1299+
training_args = SFTConfig(
1300+
output_dir=tmp_dir,
1301+
eval_strategy="steps",
1302+
max_steps=4,
1303+
eval_steps=2,
1304+
save_steps=2,
1305+
per_device_train_batch_size=2,
1306+
model_init_kwargs={"torch_dtype": -1},
1307+
)
1308+
with pytest.raises(
1309+
ValueError,
1310+
match="Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
1311+
):
1312+
_ = SFTTrainer(
1313+
model=self.model_id,
1314+
args=training_args,
1315+
train_dataset=self.train_dataset,
1316+
eval_dataset=self.eval_dataset,
1317+
)

trl/trainer/dpo_trainer.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,19 @@ def __init__(
181181
)
182182
else:
183183
model_init_kwargs = args.model_init_kwargs
184-
model_init_kwargs["torch_dtype"] = (
185-
model_init_kwargs["torch_dtype"]
186-
if model_init_kwargs["torch_dtype"] in ["auto", None]
187-
else getattr(torch, model_init_kwargs["torch_dtype"])
188-
)
184+
185+
torch_dtype = model_init_kwargs["torch_dtype"]
186+
if torch_dtype is not None:
187+
# Convert to `torch.dtype` if an str is passed
188+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
189+
torch_dtype = getattr(torch, torch_dtype)
190+
191+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
192+
raise ValueError(
193+
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
194+
)
195+
196+
model_init_kwargs["torch_dtype"] = torch_dtype
189197

190198
if ref_model_init_kwargs is not None:
191199
warnings.warn(
@@ -201,11 +209,18 @@ def __init__(
201209
)
202210
else:
203211
ref_model_init_kwargs = args.ref_model_init_kwargs
204-
ref_model_init_kwargs["torch_dtype"] = (
205-
ref_model_init_kwargs["torch_dtype"]
206-
if ref_model_init_kwargs["torch_dtype"] in ["auto", None]
207-
else getattr(torch, ref_model_init_kwargs["torch_dtype"])
208-
)
212+
torch_dtype = ref_model_init_kwargs["torch_dtype"]
213+
if torch_dtype is not None:
214+
# Convert to `torch.dtype` if an str is passed
215+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
216+
torch_dtype = getattr(torch, torch_dtype)
217+
218+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
219+
raise ValueError(
220+
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
221+
)
222+
223+
ref_model_init_kwargs["torch_dtype"] = torch_dtype
209224

210225
if isinstance(model, str):
211226
warnings.warn(

trl/trainer/sft_trainer.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,19 @@ def __init__(
162162
raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
163163
else:
164164
model_init_kwargs = args.model_init_kwargs
165-
model_init_kwargs["torch_dtype"] = (
166-
model_init_kwargs["torch_dtype"]
167-
if model_init_kwargs["torch_dtype"] in ["auto", None]
168-
else getattr(torch, model_init_kwargs["torch_dtype"])
169-
)
165+
166+
torch_dtype = model_init_kwargs["torch_dtype"]
167+
if torch_dtype is not None:
168+
# Convert to `torch.dtype` if an str is passed
169+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
170+
torch_dtype = getattr(torch, torch_dtype)
171+
172+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
173+
raise ValueError(
174+
f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
175+
)
176+
177+
model_init_kwargs["torch_dtype"] = torch_dtype
170178

171179
if infinite is not None:
172180
warnings.warn(

0 commit comments

Comments
 (0)