@@ -50,6 +50,8 @@ def run_wrenformer(
     checkpoint: Literal["local", "wandb"] | None = None,
     swa_start=0.7,
     run_params: dict[str, Any] = None,
+    optimizer: str | tuple[str, dict] = "AdamW",
+    scheduler: str | tuple[str, dict] = "LambdaLR",
     learning_rate: float = 3e-4,
     batch_size: int = 128,
     warmup_steps: int = 10,
@@ -82,9 +84,16 @@ def run_wrenformer(
             ```
         swa_start (float | None): When to start using stochastic weight averaging during training.
             Should be a float between 0 and 1. 0.7 means start SWA after 70% of epochs. Set to
-            None to disable SWA. Defaults to 0.7.
+            None to disable SWA. Defaults to 0.7. Proposed in https://arxiv.org/abs/1803.05407.
         run_params (dict[str, Any]): Additional parameters to merge into the run's dict of
             hyperparams. Will be logged to wandb. Can be anything really. Defaults to {}.
+        optimizer (str | tuple[str, dict]): Name of a torch.optim.Optimizer class like 'Adam',
+            'AdamW', 'SGD', etc. Can be a string or a string and dict with params to pass to the
+            class. Defaults to 'AdamW'.
+        scheduler (str | tuple[str, dict]): Name of a torch.optim.lr_scheduler class like
+            'LambdaLR', 'StepLR', 'CosineAnnealingLR', etc. Defaults to 'LambdaLR'. Can be a string
+            or a string and dict with params to pass to the class. E.g.
+            ('CosineAnnealingLR', {'T_max': n_epochs}).
         learning_rate (float): The optimizer's learning rate. Defaults to 3e-4.
         batch_size (int): The mini-batch size during training. Defaults to 128.
         warmup_steps (int): How many warmup steps the scheduler should do. Defaults to 10.
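
As a hedged illustration of the new arguments documented above (the specific classes and
keyword values here are examples, not part of the commit), accepted values look like:

# Names must match classes in torch.optim / torch.optim.lr_scheduler.
optimizer = "AdamW"                                # bare name, default class params
optimizer = ("SGD", {"momentum": 0.9})             # (name, extra kwargs for the class)
scheduler = ("CosineAnnealingLR", {"T_max": 100})  # T_max would typically be n_epochs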
@@ -181,31 +190,53 @@ def run_wrenformer(
         embedding_aggregations=embedding_aggregations,
     )
     model.to(device)
-    optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
+    if isinstance(optimizer, str):
+        optimizer_name, optimizer_params = optimizer, None
+    elif isinstance(optimizer, (tuple, list)):
+        optimizer_name, optimizer_params = optimizer
+    else:
+        raise ValueError(f"Unknown {optimizer=}")
+    optimizer_cls = getattr(torch.optim, optimizer_name)
+    optimizer_instance = optimizer_cls(
+        params=model.parameters(), lr=learning_rate, **(optimizer_params or {})
+    )

     # This lambda goes up linearly until warmup_steps, then follows a power law decay.
     # Acts as a prefactor to the learning rate, i.e. actual_lr = lr_lambda(epoch) *
     # learning_rate.
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda epoch: min((epoch + 1) ** (-0.5), (epoch + 1) * warmup_steps ** (-1.5)),
-    )
+    if scheduler == "LambdaLR":
+        scheduler_name, scheduler_params = "LambdaLR", {
+            "lr_lambda": lambda epoch: min(
+                (epoch + 1) ** (-0.5), (epoch + 1) * warmup_steps ** (-1.5)
+            )
+        }
+    elif isinstance(scheduler, str):
+        scheduler_name, scheduler_params = scheduler, None
+    elif isinstance(scheduler, (tuple, list)):
+        scheduler_name, scheduler_params = scheduler
+    else:
+        raise ValueError(f"Unknown {scheduler=}")
+    scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
+    scheduler_instance = scheduler_cls(optimizer_instance, **(scheduler_params or {}))

     if swa_start is not None:
         swa_model = AveragedModel(model)
-        # scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
-        swa_scheduler = SWALR(optimizer, swa_lr=0.01)
+        swa_scheduler_instance = SWALR(optimizer_instance, swa_lr=0.01)

     run_params = {
         "epochs": epochs,
+        "optimizer": optimizer_name,
+        "optimizer_params": optimizer_params,
         "learning_rate": learning_rate,
+        "lr_scheduler": scheduler_name,
+        "scheduler_params": scheduler_params,
         "batch_size": batch_size,
         "n_attn_layers": n_attn_layers,
         "target": target_col,
         "warmup_steps": warmup_steps,
         "robust": robust,
         "embedding_len": embedding_len,
-        "losses": str(loss_dict),
+        "losses": loss_dict,
         "training_samples": len(train_df),
         "test_samples": len(test_df),
         "trainable_params": model.num_params,
@@ -232,7 +263,7 @@ def run_wrenformer(
         train_metrics = model.evaluate(
             train_loader,
             loss_dict,
-            optimizer,
+            optimizer_instance,
             normalizer_dict,
             action="train",
             verbose=verbose,
@@ -250,10 +281,10 @@ def run_wrenformer(

         if swa_start is not None and epoch > swa_start * epochs:
             swa_model.update_parameters(model)
-            swa_scheduler.step()
+            swa_scheduler_instance.step()
         else:
-            scheduler.step()
-            scheduler.step()
+            scheduler_instance.step()
+
         model.epoch += 1

         if wandb_project:
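
The SWA branch in this hunk can be illustrated standalone as a rough sketch; the dummy
model, the single fake training step, and the epoch count are assumptions, while the
swa_start threshold and swa_lr mirror the values in the diff.

import torch
from torch.optim.swa_utils import SWALR, AveragedModel

model = torch.nn.Linear(8, 1)  # stand-in for the Wrenformer model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1.0)

epochs, swa_start = 10, 0.7  # SWA takes over after 70% of epochs
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.01)

for epoch in range(epochs):
    # one (dummy) optimization step standing in for a full training epoch
    optimizer.zero_grad()
    model(torch.randn(4, 8)).mean().backward()
    optimizer.step()

    if swa_start is not None and epoch > swa_start * epochs:
        # accumulate averaged weights and let the SWA scheduler drive the learning rate
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()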
@@ -293,9 +324,9 @@ def run_wrenformer(
     if checkpoint is not None:
         state_dict = {
             "model_state": inference_model.state_dict(),
-            "optimizer_state": optimizer.state_dict(),
+            "optimizer_state": optimizer_instance.state_dict(),
             "scheduler_state": (
-                scheduler if swa_start is None else swa_scheduler
+                scheduler_instance if swa_start is None else swa_scheduler_instance
             ).state_dict(),
             "loss_dict": loss_dict,
             "epoch": epochs,