
Commit 31bd925 (parent: a00b78b)

add kwargs optimizer and scheduler to run_wrenformer() with defaults 'AdamW' and 'LambdaLR'

File tree: 3 files changed, +54 −20 lines

examples/mat_bench/slurm_submit.py
examples/mp_wbm/slurm_submit.py
examples/wrenformer.py


examples/mat_bench/slurm_submit.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 folds = list(range(5))
 checkpoint = None  # None | 'local' | 'wandb'
 lr = 3e-4
-model_name = f"wrenformer-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace("e-0", "e-")
+model_name = f"wrenformer-{lr=}-{epochs=}-{n_attn_layers=}"

 if "roost" in model_name.lower():
     # deploy Roost on all tasks
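For context on the model_name change above, here is a small illustration of what the old and new f-strings render to (epochs and n_attn_layers are set elsewhere in the script, so the values below are assumed purely for demonstration):

lr, epochs, n_attn_layers = 3e-4, 300, 3  # assumed values for illustration
old_name = f"wrenformer-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace("e-0", "e-")
new_name = f"wrenformer-{lr=}-{epochs=}-{n_attn_layers=}"
print(old_name)  # wrenformer-lr=3e-4-epochs=300-n_attn_layers=3
print(new_name)  # wrenformer-lr=0.0003-epochs=300-n_attn_layers=3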

examples/mp_wbm/slurm_submit.py
Lines changed: 9 additions & 6 deletions

@@ -13,17 +13,18 @@
 # %% write Python submission file and sbatch it
 epochs = 300
 n_attn_layers = 3
-embedding_aggregations = ("mean", "std", "min", "max")
+embedding_aggregations = ("mean",)
+optimizer = "AdamW"
+lr = 3e-4
+scheduler = ("CosineAnnealingLR", {"T_max": epochs})
 n_folds = 1
 data_path = f"{ROOT}/datasets/2022-06-09-mp+wbm.json.gz"
 target = "e_form"
 task_type = "regression"
 checkpoint = None  # None | 'local' | 'wandb'
-lr = 3e-4
-batch_size = 64
-model_name = f"wrenformer-robust-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace(
-    "e-0", "e-"
-)
+batch_size = 128
+model_name = f"wrenformer-robust-{epochs=}"
+

 os.makedirs(log_dir := f"{MODULE_DIR}/job-logs", exist_ok=True)
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M}"

@@ -52,7 +53,9 @@
 {epochs=},
 {n_attn_layers=},
 {checkpoint=},
+{optimizer=},
 learning_rate={lr},
+{scheduler=},
 {embedding_aggregations=},
 {batch_size=},
 )
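The two lines added to the generated run_wrenformer() call rely on Python's self-documenting f-string syntax: {optimizer=} expands to the variable name plus the repr of its value, so the submission script receives valid keyword arguments. A quick demonstration using the values defined in this file:

epochs = 300
optimizer = "AdamW"
scheduler = ("CosineAnnealingLR", {"T_max": epochs})
print(f"{optimizer=},")  # optimizer='AdamW',
print(f"{scheduler=},")  # scheduler=('CosineAnnealingLR', {'T_max': 300}),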

examples/wrenformer.py
Lines changed: 44 additions & 13 deletions

@@ -50,6 +50,8 @@ def run_wrenformer(
     checkpoint: Literal["local", "wandb"] | None = None,
     swa_start=0.7,
     run_params: dict[str, Any] = None,
+    optimizer: str | tuple[str, dict] = "AdamW",
+    scheduler: str | tuple[str, dict] = "LambdaLR",
     learning_rate: float = 3e-4,
     batch_size: int = 128,
     warmup_steps: int = 10,

@@ -82,9 +84,16 @@ def run_wrenformer(
         ```
         swa_start (float | None): When to start using stochastic weight averaging during training.
             Should be a float between 0 and 1. 0.7 means start SWA after 70% of epochs. Set to
-            None to disable SWA. Defaults to 0.7.
+            None to disable SWA. Defaults to 0.7. Proposed in https://arxiv.org/abs/1803.05407.
         run_params (dict[str, Any]): Additional parameters to merge into the run's dict of
             hyperparams. Will be logged to wandb. Can be anything really. Defaults to {}.
+        optimizer (str | tuple[str, dict]): Name of a torch.optim.Optimizer class like 'Adam',
+            'AdamW', 'SGD', etc. Can be a string or a string and dict with params to pass to the
+            class. Defaults to 'AdamW'.
+        scheduler (str | tuple[str, dict]): Name of a torch.optim.lr_scheduler class like
+            'LambdaLR', 'StepLR', 'CosineAnnealingLR', etc. Defaults to 'LambdaLR'. Can be a string
+            or a string and dict with params to pass to the class. E.g.
+            ('CosineAnnealingLR', {'T_max': n_epochs}).
         learning_rate (float): The optimizer's learning rate. Defaults to 3e-4.
         batch_size (int): The mini-batch size during training. Defaults to 128.
         warmup_steps (int): How many warmup steps the scheduler should do. Defaults to 10.

@@ -181,24 +190,46 @@ def run_wrenformer(
         embedding_aggregations=embedding_aggregations,
     )
     model.to(device)
-    optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
+    if isinstance(optimizer, str):
+        optimizer_name, optimizer_params = optimizer, None
+    elif isinstance(optimizer, (tuple, list)):
+        optimizer_name, optimizer_params = optimizer
+    else:
+        raise ValueError(f"Unknown {optimizer=}")
+    optimizer_cls = getattr(torch.optim, optimizer_name)
+    optimizer_instance = optimizer_cls(
+        params=model.parameters(), lr=learning_rate, **(optimizer_params or {})
+    )

     # This lambda goes up linearly until warmup_steps, then follows a power law decay.
     # Acts as a prefactor to the learning rate, i.e. actual_lr = lr_lambda(epoch) *
     # learning_rate.
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda epoch: min((epoch + 1) ** (-0.5), (epoch + 1) * warmup_steps ** (-1.5)),
-    )
+    if scheduler == "LambdaLR":
+        scheduler_name, scheduler_params = "LambdaLR", {
+            "lr_lambda": lambda epoch: min(
+                (epoch + 1) ** (-0.5), (epoch + 1) * warmup_steps ** (-1.5)
+            )
+        }
+    elif isinstance(scheduler, str):
+        scheduler_name, scheduler_params = scheduler, None
+    elif isinstance(scheduler, (tuple, list)):
+        scheduler_name, scheduler_params = scheduler
+    else:
+        raise ValueError(f"Unknown {scheduler=}")
+    scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
+    scheduler_instance = scheduler_cls(optimizer_instance, **(scheduler_params or {}))

     if swa_start is not None:
         swa_model = AveragedModel(model)
-        # scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
-        swa_scheduler = SWALR(optimizer, swa_lr=0.01)
+        swa_scheduler_instance = SWALR(optimizer_instance, swa_lr=0.01)

     run_params = {
         "epochs": epochs,
+        "optimizer": optimizer_name,
+        "optimizer_params": optimizer_params,
         "learning_rate": learning_rate,
+        "lr_scheduler": scheduler_name,
+        "scheduler_params": scheduler_params,
         "batch_size": batch_size,
         "n_attn_layers": n_attn_layers,
         "target": target_col,

@@ -232,7 +263,7 @@ def run_wrenformer(
         train_metrics = model.evaluate(
             train_loader,
             loss_dict,
-            optimizer,
+            optimizer_instance,
             normalizer_dict,
             action="train",
             verbose=verbose,

@@ -250,9 +281,9 @@

         if swa_start is not None and epoch > swa_start * epochs:
             swa_model.update_parameters(model)
-            swa_scheduler.step()
+            swa_scheduler_instance.step()
         else:
-            scheduler.step()
+            scheduler_instance.step()

         model.epoch += 1

@@ -293,9 +324,9 @@
     if checkpoint is not None:
         state_dict = {
             "model_state": inference_model.state_dict(),
-            "optimizer_state": optimizer.state_dict(),
+            "optimizer_state": optimizer_instance.state_dict(),
             "scheduler_state": (
-                scheduler if swa_start is None else swa_scheduler
+                scheduler_instance if swa_start is None else swa_scheduler_instance
             ).state_dict(),
             "loss_dict": loss_dict,
             "epoch": epochs,
