
Commit 62896ed

Merge pull request #47 from CompRhys/fix-wrenformer-scheduler
Fix Wrenformer scheduler

2 parents a58d6fc + 9248739

6 files changed: +77 additions, -30 deletions

.pre-commit-config.yaml

Lines changed: 0 additions & 3 deletions

@@ -14,9 +14,6 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-        args: [--ignore, 'W503,E203']
-        # W503 black conflicts with "line break before operator" rule
-        # E203 black conflicts with "whitespace before ':'" rule
 
   - repo: https://github.com/asottile/pyupgrade
     rev: v2.31.1

aviary/wrenformer/utils.py

Lines changed: 18 additions & 3 deletions

@@ -1,7 +1,7 @@
 import json
 import time
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Literal
 
 
 def _int_keys(dct: dict) -> dict:
@@ -21,12 +21,20 @@ def recursive_dict_merge(d1: dict, d2: dict) -> dict:
     return d1
 
 
-def merge_json_on_disk(dct: dict, file_path: str) -> None:
+def merge_json_on_disk(
+    dct: dict,
+    file_path: str,
+    on_non_serializable: Literal["annotate", "error"] = "annotate",
+) -> None:
     """Merge a dict into a (possibly) existing JSON file.
 
     Args:
         file_path (str): Path to JSON file. File will be created if not exist.
         dct (dict): Dictionary to merge into JSON file.
+        on_non_serializable ('annotate' | 'error'): What to do with non-serializable values
+            encountered in dct. 'annotate' will replace the offending object with a string
+            indicating the type, e.g. '<not serializable: function>'. 'error' will raise
+            'TypeError: Object of type function is not JSON serializable'. Defaults to 'annotate'.
     """
     try:
         with open(file_path) as json_file:
@@ -36,8 +44,15 @@ def merge_json_on_disk(dct: dict, file_path: str) -> None:
     except (FileNotFoundError, json.decoder.JSONDecodeError):  # file missing or empty
         pass
 
+    def non_serializable_handler(obj: object) -> str:
+        # replace functions and classes in dct with string indicating a non-serializable type
+        return f"<not serializable: {type(obj).__qualname__}>"
+
     with open(file_path, "w") as file:
-        json.dump(dct, file)
+        default = (
+            non_serializable_handler if on_non_serializable == "annotate" else None
+        )
+        json.dump(dct, file, default=default)
 
 
 @contextmanager
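For illustration, a minimal usage sketch of the new on_non_serializable argument, assuming the aviary package is installed so the function can be imported from the module above; the temp file and example values are not part of this commit:

# Usage sketch (not part of this commit): the default 'annotate' replaces
# non-serializable values with a placeholder string, 'error' lets json raise.
import json
from tempfile import NamedTemporaryFile

from aviary.wrenformer.utils import merge_json_on_disk

tmp_path = NamedTemporaryFile(suffix=".json", delete=False).name

# 'annotate' (default): the lambda is stored as '<not serializable: function>'
merge_json_on_disk({"lr": 3e-4, "lr_lambda": lambda epoch: epoch}, tmp_path)
print(json.load(open(tmp_path)))

# 'error': json.dump gets no default handler and raises TypeError
try:
    merge_json_on_disk({"fn": print}, tmp_path, on_non_serializable="error")
except TypeError as exc:
    print(exc)  # e.g. "Object of type builtin_function_or_method is not JSON serializable"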

examples/mat_bench/slurm_submit.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 folds = list(range(5))
 checkpoint = None  # None | 'local' | 'wandb'
 lr = 3e-4
-model_name = f"wrenformer-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace("e-0", "e-")
+model_name = f"wrenformer-{lr=}-{epochs=}-{n_attn_layers=}"
 
 if "roost" in model_name.lower():
     # deploy Roost on all tasks
examples/mp_wbm/slurm_submit.py

Lines changed: 9 additions & 6 deletions

@@ -13,17 +13,18 @@
 # %% write Python submission file and sbatch it
 epochs = 300
 n_attn_layers = 3
-embedding_aggregations = ("mean", "std", "min", "max")
+embedding_aggregations = ("mean",)
+optimizer = "AdamW"
+lr = 3e-4
+scheduler = ("CosineAnnealingLR", {"T_max": epochs})
 n_folds = 1
 data_path = f"{ROOT}/datasets/2022-06-09-mp+wbm.json.gz"
 target = "e_form"
 task_type = "regression"
 checkpoint = None  # None | 'local' | 'wandb'
-lr = 3e-4
-batch_size = 64
-model_name = f"wrenformer-robust-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace(
-    "e-0", "e-"
-)
+batch_size = 128
+model_name = f"wrenformer-robust-{epochs=}"
+
 
 os.makedirs(log_dir := f"{MODULE_DIR}/job-logs", exist_ok=True)
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M}"
@@ -52,7 +53,9 @@
     {epochs=},
     {n_attn_layers=},
     {checkpoint=},
+    {optimizer=},
     learning_rate={lr},
+    {scheduler=},
     {embedding_aggregations=},
     {batch_size=},
 )
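A quick note on the template above: the {optimizer=} / {scheduler=} fields use Python's self-documenting f-string syntax, so the generated submission script contains ready-made keyword arguments. A minimal sketch with the values from this diff (the run_wrenformer call is only rendered as a string here, not executed):

# Sketch of how the f-string `=` fields render in the generated script
epochs = 300
optimizer = "AdamW"
lr = 3e-4
scheduler = ("CosineAnnealingLR", {"T_max": epochs})

print(f"""run_wrenformer(
    {epochs=},
    {optimizer=},
    learning_rate={lr},
    {scheduler=},
)""")
# run_wrenformer(
#     epochs=300,
#     optimizer='AdamW',
#     learning_rate=0.0003,
#     scheduler=('CosineAnnealingLR', {'T_max': 300}),
# )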

examples/wrenformer.py

Lines changed: 46 additions & 15 deletions

@@ -50,6 +50,8 @@ def run_wrenformer(
     checkpoint: Literal["local", "wandb"] | None = None,
     swa_start=0.7,
     run_params: dict[str, Any] = None,
+    optimizer: str | tuple[str, dict] = "AdamW",
+    scheduler: str | tuple[str, dict] = "LambdaLR",
     learning_rate: float = 3e-4,
     batch_size: int = 128,
     warmup_steps: int = 10,
@@ -82,9 +84,16 @@ def run_wrenformer(
         ```
         swa_start (float | None): When to start using stochastic weight averaging during training.
             Should be a float between 0 and 1. 0.7 means start SWA after 70% of epochs. Set to
-            None to disable SWA. Defaults to 0.7.
+            None to disable SWA. Defaults to 0.7. Proposed in https://arxiv.org/abs/1803.05407.
         run_params (dict[str, Any]): Additional parameters to merge into the run's dict of
             hyperparams. Will be logged to wandb. Can be anything really. Defaults to {}.
+        optimizer (str | tuple[str, dict]): Name of a torch.optim.Optimizer class like 'Adam',
+            'AdamW', 'SGD', etc. Can be a string or a string and dict with params to pass to the
+            class. Defaults to 'AdamW'.
+        scheduler (str | tuple[str, dict]): Name of a torch.optim.lr_scheduler class like
+            'LambdaLR', 'StepLR', 'CosineAnnealingLR', etc. Defaults to 'LambdaLR'. Can be a string
+            or a string and dict with params to pass to the class. E.g.
+            ('CosineAnnealingLR', {'T_max': n_epochs}).
         learning_rate (float): The optimizer's learning rate. Defaults to 3e-4.
         batch_size (int): The mini-batch size during training. Defaults to 128.
         warmup_steps (int): How many warmup steps the scheduler should do. Defaults to 10.
@@ -181,31 +190,53 @@ def run_wrenformer(
         embedding_aggregations=embedding_aggregations,
     )
     model.to(device)
-    optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
+    if isinstance(optimizer, str):
+        optimizer_name, optimizer_params = optimizer, None
+    elif isinstance(optimizer, (tuple, list)):
+        optimizer_name, optimizer_params = optimizer
+    else:
+        raise ValueError(f"Unknown {optimizer=}")
+    optimizer_cls = getattr(torch.optim, optimizer_name)
+    optimizer_instance = optimizer_cls(
+        params=model.parameters(), lr=learning_rate, **(optimizer_params or {})
+    )
 
     # This lambda goes up linearly until warmup_steps, then follows a power law decay.
     # Acts as a prefactor to the learning rate, i.e. actual_lr = lr_lambda(epoch) *
     # learning_rate.
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda epoch: min((epoch + 1) ** (-0.5), (epoch + 1) * warmup_steps ** (-1.5)),
-    )
+    if scheduler == "LambdaLR":
+        scheduler_name, scheduler_params = "LambdaLR", {
+            "lr_lambda": lambda epoch: min(
+                (epoch + 1) ** (-0.5), (epoch + 1) * warmup_steps ** (-1.5)
+            )
+        }
+    elif isinstance(scheduler, str):
+        scheduler_name, scheduler_params = scheduler, None
+    elif isinstance(scheduler, (tuple, list)):
+        scheduler_name, scheduler_params = scheduler
+    else:
+        raise ValueError(f"Unknown {scheduler=}")
+    scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
+    scheduler_instance = scheduler_cls(optimizer_instance, **(scheduler_params or {}))
 
     if swa_start is not None:
         swa_model = AveragedModel(model)
-        # scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
-        swa_scheduler = SWALR(optimizer, swa_lr=0.01)
+        swa_scheduler_instance = SWALR(optimizer_instance, swa_lr=0.01)
 
     run_params = {
         "epochs": epochs,
+        "optimizer": optimizer_name,
+        "optimizer_params": optimizer_params,
         "learning_rate": learning_rate,
+        "lr_scheduler": scheduler_name,
+        "scheduler_params": scheduler_params,
         "batch_size": batch_size,
         "n_attn_layers": n_attn_layers,
         "target": target_col,
         "warmup_steps": warmup_steps,
         "robust": robust,
         "embedding_len": embedding_len,
-        "losses": str(loss_dict),
+        "losses": loss_dict,
         "training_samples": len(train_df),
         "test_samples": len(test_df),
         "trainable_params": model.num_params,
@@ -232,7 +263,7 @@
         train_metrics = model.evaluate(
             train_loader,
             loss_dict,
-            optimizer,
+            optimizer_instance,
             normalizer_dict,
             action="train",
             verbose=verbose,
@@ -250,10 +281,10 @@
 
         if swa_start is not None and epoch > swa_start * epochs:
             swa_model.update_parameters(model)
-            swa_scheduler.step()
+            swa_scheduler_instance.step()
         else:
-            scheduler.step()
-        scheduler.step()
+            scheduler_instance.step()
+
         model.epoch += 1
 
         if wandb_project:
@@ -293,9 +324,9 @@
     if checkpoint is not None:
        state_dict = {
            "model_state": inference_model.state_dict(),
-            "optimizer_state": optimizer.state_dict(),
+            "optimizer_state": optimizer_instance.state_dict(),
            "scheduler_state": (
-                scheduler if swa_start is None else swa_scheduler
+                scheduler_instance if swa_start is None else swa_scheduler_instance
            ).state_dict(),
            "loss_dict": loss_dict,
            "epoch": epochs,

setup.cfg

Lines changed: 3 additions & 2 deletions

@@ -1,7 +1,8 @@
 [flake8]
 max-line-length = 100
-# E203: whitespace before ':' preferred by black
-ignore = E203
+# E203: black conflicts with "whitespace before ':'" rule
+# W503 black conflicts with "line break before operator" rule
+ignore = E203,W503
 
 [isort]
 profile = black
