|
94 | 94 | console = Console() |
95 | 95 |
|
96 | 96 | ################ |
97 | | - # Model & Tokenizer |
| 97 | + # Model init kwargs & Tokenizer |
98 | 98 | ################ |
99 | | - torch_dtype = ( |
100 | | - model_config.torch_dtype |
101 | | - if model_config.torch_dtype in ["auto", None] |
102 | | - else getattr(torch, model_config.torch_dtype) |
103 | | - ) |
104 | 99 | quantization_config = get_quantization_config(model_config) |
105 | 100 | model_kwargs = dict( |
106 | 101 | revision=model_config.model_revision, |
107 | 102 | trust_remote_code=model_config.trust_remote_code, |
108 | 103 | attn_implementation=model_config.attn_implementation, |
109 | | - torch_dtype=torch_dtype, |
| 104 | + torch_dtype=model_config.torch_dtype, |
110 | 105 | use_cache=False if training_args.gradient_checkpointing else True, |
111 | 106 | device_map=get_kbit_device_map() if quantization_config is not None else None, |
112 | 107 | quantization_config=quantization_config, |
113 | 108 | ) |
| 109 | + training_args.model_init_kwargs = model_kwargs |
114 | 110 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True) |
115 | 111 | tokenizer.pad_token = tokenizer.eos_token |
116 | 112 |
|
|
138 | 134 | with init_context: |
139 | 135 | trainer = SFTTrainer( |
140 | 136 | model=model_config.model_name_or_path, |
141 | | - model_init_kwargs=model_kwargs, |
142 | 137 | args=training_args, |
143 | 138 | train_dataset=train_dataset, |
144 | 139 | eval_dataset=eval_dataset, |
|
0 commit comments