Skip to content

Commit b6af2ed

Browse files
authored
add model_init_kwargs to training_args (#1787)
1 parent cd85b14 commit b6af2ed

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

examples/scripts/sft.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -94,23 +94,19 @@
94 94
console = Console()
95 95

96 96
################
97-
# Model & Tokenizer
97+
# Model init kwargs & Tokenizer
98 98
################
99-
torch_dtype = (
100-
model_config.torch_dtype
101-
if model_config.torch_dtype in ["auto", None]
102-
else getattr(torch, model_config.torch_dtype)
103-
)
104 99
quantization_config = get_quantization_config(model_config)
105 100
model_kwargs = dict(
106 101
revision=model_config.model_revision,
107 102
trust_remote_code=model_config.trust_remote_code,
108 103
attn_implementation=model_config.attn_implementation,
109-
torch_dtype=torch_dtype,
104+
torch_dtype=model_config.torch_dtype,
110 105
use_cache=False if training_args.gradient_checkpointing else True,
111 106
device_map=get_kbit_device_map() if quantization_config is not None else None,
112 107
quantization_config=quantization_config,
113 108
)
109+
training_args.model_init_kwargs = model_kwargs
114 110
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
115 111
tokenizer.pad_token = tokenizer.eos_token
116 112

@@ -138,7 +134,6 @@
138 134
with init_context:
139 135
trainer = SFTTrainer(
140 136
model=model_config.model_name_or_path,
141-
model_init_kwargs=model_kwargs,
142 137
args=training_args,
143 138
train_dataset=train_dataset,
144 139
eval_dataset=eval_dataset,

0 commit comments

Comments
 (0)