
Commit 5b9a49b

make embedding_aggregations an explicit kwarg to wrenformer

1 parent 65ff974, commit 5b9a49b

File tree

5 files changed: +64 -33 lines

aviary/core.py
Lines changed: 12 additions & 13 deletions

@@ -324,20 +324,19 @@ def evaluate(
         mixed_loss.backward()
         optimizer.step()
 
-    epoch_averaged_metrics = {
-        target: {
-            metric_key: np.array(values).mean().squeeze().round(4)
-            for metric_key, values in dct.items()
+    avrg_metrics: dict[str, dict[str, float]] = {}
+    for target, per_batch_metrics in epoch_metrics.items():
+        avrg_metrics[target] = {
+            metric_key: (np.array(values).mean().squeeze().round(4))
+            for metric_key, values in per_batch_metrics.items()
         }
-        for target, dct in epoch_metrics.items()
-    }
-    # take sqrt at the end to get correct epoch RMSE
-    # per-batch averaged RMSE != RMSE of full epoch since (sqrt(a + b) != sqrt(a) + sqrt(b))
-    for metrics_for_target in epoch_averaged_metrics.values():
-        if "MSE" in metrics_for_target:
-            metrics_for_target["RMSE"] = metrics_for_target.pop("MSE") ** 0.5
-
-    return epoch_averaged_metrics
+        # take sqrt at the end to get correct epoch RMSE as per-batch averaged RMSE
+        # != RMSE of full epoch since (sqrt(a + b) != sqrt(a) + sqrt(b))
+        avrg_mse = avrg_metrics[target].pop("MSE")
+        if avrg_mse:
+            avrg_metrics[target]["RMSE"] = (avrg_mse**0.5).round(4)
+
+    return avrg_metrics
 
 @torch.no_grad()
 def predict(
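
The new comments explain that the epoch RMSE must be derived from the epoch-averaged MSE, since averaging per-batch RMSEs gives a different number (sqrt is not additive). A minimal sketch of that distinction, illustrative only and not part of the commit, assuming equal-sized batches and made-up errors:

```py
import numpy as np

rng = np.random.default_rng(0)
errors = rng.normal(size=1000)  # hypothetical per-sample prediction errors
batches = np.split(errors, 10)  # 10 equal-sized batches

per_batch_mse = [np.mean(b**2) for b in batches]
rmse_from_avg_mse = np.mean(per_batch_mse) ** 0.5  # sqrt taken at the end, as in the new code
avg_of_per_batch_rmse = np.mean([m**0.5 for m in per_batch_mse])  # biased low (Jensen's inequality)

full_epoch_rmse = np.mean(errors**2) ** 0.5
assert np.isclose(rmse_from_avg_mse, full_epoch_rmse)  # matches the true epoch RMSE
print(rmse_from_avg_mse, avg_of_per_batch_rmse, full_epoch_rmse)
```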

aviary/wrenformer/model.py
Lines changed: 25 additions & 11 deletions

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Callable, Sequence
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -35,6 +37,7 @@ def __init__(
         trunk_hidden: list[int] = [1024, 512],
         out_hidden: list[int] = [256, 128, 64],
         robust: bool = False,
+        embedding_aggregations: Sequence[str] = ("mean",),
         **kwargs,
     ) -> None:
         """Initialize the Wrenformer model.

@@ -57,6 +60,9 @@ def __init__(
                 target will be an estimate for the aleatoric uncertainty (uncertainty inherent to
                 the sample) which can be used with a robust loss function to attenuate the weighting
                 of uncertain samples.
+            embedding_aggregations (list[str]): Aggregations to apply to the learned embedding
+                returned by the transformer encoder before passing into the ResidualNetwork. One or
+                more of ['mean', 'std', 'sum', 'min', 'max']. Defaults to ['mean'].
         """
         super().__init__(robust=robust, **kwargs)
 

@@ -73,9 +79,10 @@ def __init__(
         if self.robust:
             n_targets = [2 * n for n in n_targets]
 
-        n_aggregators = 2  # number of embedding aggregation functions
+        self.embedding_aggregations = embedding_aggregations
         self.trunk_nn = ResidualNetwork(
-            input_dim=n_aggregators * d_model,
+            # len(embedding_aggregations) = number of catted tensors in aggregated_embeddings below
+            input_dim=len(embedding_aggregations) * d_model,
             output_dim=out_hidden[0],
             hidden_layer_dims=trunk_hidden,
         )

@@ -123,18 +130,25 @@ def forward(  # type: ignore
         # into a single vector Wyckoff embedding
         # careful to ignore padded values when taking the mean
         inv_mask: torch.BoolTensor = ~mask[..., None]
-        # sum_agg = (embeddings * inv_mask).sum(dim=1)
-
-        # # replace padded values with +/-inf to exclude them from min/max
-        # min_agg, _ = torch.where(inv_mask, embeddings, float("inf")).min(dim=1)
-        # max_agg, _ = torch.where(inv_mask, embeddings, float("-inf")).max(dim=1)
-        mean_agg = masked_mean(embeddings, inv_mask, dim=1)
-        std_agg = masked_std(embeddings, inv_mask, dim=1)
 
-        # Sum+Std+Min+Max+Mean: we call this S2M3 aggregation
-        aggregated_embeddings = torch.cat([mean_agg, std_agg], dim=1)
+        aggregation_funcs = [aggregators[key] for key in self.embedding_aggregations]
+        aggregated_embeddings = torch.cat(
+            [func(embeddings, inv_mask, 1) for func in aggregation_funcs], dim=1
+        )
 
         # main body of the feed-forward NN jointly used by all multitask objectives
         predictions = F.relu(self.trunk_nn(aggregated_embeddings))
 
         return tuple(output_nn(predictions) for output_nn in self.output_nns)
+
+
+# using all at once we call this S2M3 aggregation
+aggregators: dict[str, Callable[[Tensor, BoolTensor, int], Tensor]] = {
+    "mean": masked_mean,
+    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
+    "std": masked_std,
+    # replace padded values with +/-inf to make sure min()/max() ignore them
+    "min": lambda x, mask, dim: torch.where(mask, x, float("inf")).min(dim=dim)[0],
+    # 1st ret val = max, 2nd ret val = max indices
+    "max": lambda x, mask, dim: torch.where(mask, x, float("-inf")).max(dim=dim)[0],
+}
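
For context, here is a self-contained sketch of the masked-aggregation dispatch the new forward() relies on: aggregation functions are looked up by name, each reduces over the sequence dimension while ignoring padded positions, and the results are concatenated so that trunk_nn's input_dim equals len(embedding_aggregations) * d_model. The masked_mean below and names like real_mask are simplified stand-ins, not aviary's implementation:

```py
import torch

def masked_mean(x, mask, dim):
    # stand-in for aviary's masked_mean: average over dim, counting only unmasked positions
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim)

# simplified version of the module-level aggregators dict added in this commit
aggregators = {
    "mean": masked_mean,
    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
    # padded positions are replaced by +/-inf so min()/max() ignore them
    "min": lambda x, mask, dim: torch.where(mask, x, torch.full_like(x, float("inf"))).min(dim=dim)[0],
    "max": lambda x, mask, dim: torch.where(mask, x, torch.full_like(x, float("-inf"))).max(dim=dim)[0],
}

batch_size, seq_len, d_model = 2, 5, 8
embeddings = torch.randn(batch_size, seq_len, d_model)
# True marks real (non-padded) positions, i.e. what forward() calls inv_mask
real_mask = torch.tensor([[True] * 3 + [False] * 2, [True] * 5])[..., None]

embedding_aggregations = ("mean", "min", "max")
aggregated_embeddings = torch.cat(
    [aggregators[key](embeddings, real_mask, 1) for key in embedding_aggregations], dim=1
)
# trunk_nn's input_dim must therefore be len(embedding_aggregations) * d_model
assert aggregated_embeddings.shape == (batch_size, len(embedding_aggregations) * d_model)
```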

aviary/wrenformer/run.py
Lines changed: 14 additions & 2 deletions

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Literal
+from typing import Any, Literal, Sequence
 
 import numpy as np
 import pandas as pd

@@ -49,6 +49,7 @@ def run_wrenformer(
     run_params: dict[str, Any] = None,
     learning_rate: float = 3e-4,
     warmup_steps: int = 10,
+    embedding_aggregations: Sequence[str] = ("mean",),
 ) -> tuple[dict[str, float], dict[str, Any], pd.DataFrame]:
     """Run a single matbench task.
 

@@ -81,6 +82,9 @@ def run_wrenformer(
             hyperparams. Will be logged to wandb. Can be anything really. Defaults to {}.
         learning_rate (float): The optimizer's learning rate. Defaults to 3e-4.
         warmup_steps (int): How many warmup steps the scheduler should do. Defaults to 10.
+        embedding_aggregations (list[str]): Aggregations to apply to the learned embedding returned
+            by the transformer encoder before passing into the ResidualNetwork. One or more of
+            ['mean', 'std', 'sum', 'min', 'max']. Defaults to ['mean'].
 
     Raises:
         ValueError: On unknown dataset_name or invalid checkpoint.

@@ -107,7 +111,7 @@ def run_wrenformer(
         )
         assert "wyckoff" in df, err_msg
         with print_walltime(
-            start_desc=f"{label} Generating Wyckoff embeddings", newline=False
+            start_desc=f"Generating Wyckoff embeddings for {label}", newline=False
         ):
             df["features"] = df.wyckoff.map(wyckoff_embedding_from_aflow_str)
     elif "roost" in run_name.lower():

@@ -161,6 +165,7 @@ def run_wrenformer(
         task_dict={target_col: task_type},  # e.g. {'exfoliation_en': 'regression'}
         n_attn_layers=n_attn_layers,
         robust=robust,
+        embedding_aggregations=embedding_aggregations,
     )
     model.to(device)
     optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

@@ -192,6 +197,7 @@ def run_wrenformer(
         "trainable_params": model.num_params,
         "swa_start": swa_start,
         "timestamp": timestamp,
+        "embedding_aggregations": embedding_aggregations,
         **(run_params or {}),
     }
 

@@ -228,6 +234,12 @@ def run_wrenformer(
     wandb.log({"training": train_metrics, "validation": val_metrics})
 
     # get test set predictions
+    if swa_start is not None:
+        n_swa_epochs = int((1 - swa_start) * epochs)
+        print(
+            f"Using SWA model with weights averaged over {n_swa_epochs} epochs ({swa_start = })"
+        )
+
     inference_model = swa_model if swa_start is not None else model
     inference_model.eval()
 
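
The new SWA log line reports how many epochs the averaged weights cover. A worked example with assumed values (swa_start is only used, not defined, in this hunk):

```py
epochs = 300
swa_start = 0.7  # hypothetical value: start averaging after 70% of training
if swa_start is not None:
    n_swa_epochs = int((1 - swa_start) * epochs)
    print(f"Using SWA model with weights averaged over {n_swa_epochs} epochs ({swa_start = })")
    # prints: Using SWA model with weights averaged over 90 epochs (swa_start = 0.7)
```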

examples/mat_bench/slurm_submit.py
Lines changed: 8 additions & 4 deletions

@@ -14,10 +14,11 @@
 # %% write Python submission file and sbatch it
 epochs = 300
 n_attn_layers = 6
-model_name = f"wrenformer-robust-s2m3-aggregation-{epochs=}-{n_attn_layers=}"
+embedding_aggregations = ("mean",)
 folds = list(range(5))
-checkpoint = "wandb"  # None | 'local' | 'wandb'
-learning_rate = 1e-3
+checkpoint = None  # None | 'local' | 'wandb'
+lr = 3e-4
+model_name = f"wrenformer-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace("e-0", "e-")
 
 if "roost" in model_name.lower():
     # deploy Roost on all tasks

@@ -26,6 +27,8 @@
     # deploy Wren on structure tasks only
     datasets = [k for k, v in mbv01_metadata.items() if v.input_type == "structure"]
 
+datasets = ["matbench_mp_e_form"]
+
 os.makedirs(log_dir := f"{MODULE_DIR}/job-logs", exist_ok=True)
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M}"
 

@@ -56,7 +59,8 @@
     {epochs=},
     {n_attn_layers=},
     {checkpoint=},
-    {learning_rate=},
+    learning_rate={lr},
+    {embedding_aggregations=},
 )
 """
 
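
For reference, the new model_name template evaluates as follows (worked example using the values set above in this file):

```py
lr, epochs, n_attn_layers = 3e-4, 300, 6
model_name = f"wrenformer-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace("e-0", "e-")
# f"{lr=:.0e}" renders as 'lr=3e-04'; .replace("e-0", "e-") trims it to 'lr=3e-4'
print(model_name)  # wrenformer-lr=3e-4-epochs=300-n_attn_layers=6
```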

examples/mp_wbm/slurm_submit.py
Lines changed: 5 additions & 3 deletions

@@ -13,14 +13,15 @@
 # %% write Python submission file and sbatch it
 epochs = 300
 n_attn_layers = 6
-model_name = f"wrenformer-robust-s2m3-aggregation-{epochs=}-{n_attn_layers=}"
+embedding_aggregations = ("mean",)
 fold = 0
 n_folds = 1
 data_path = f"{ROOT}/datasets/2022-06-09-mp+wbm.json.gz"
 target = "e_form"
 task_type = "regression"
 checkpoint = "wandb"  # None | 'local' | 'wandb'
-learning_rate = 1e-3
+lr = 3e-4
+model_name = f"wrenformer-{lr=:.0e}-{epochs=}-{n_attn_layers=}".replace("e-0", "e-")
 
 os.makedirs(log_dir := f"{MODULE_DIR}/job-logs", exist_ok=True)
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M}"

@@ -51,7 +52,8 @@
     {epochs=},
     {n_attn_layers=},
     {checkpoint=},
-    {learning_rate=},
+    learning_rate={lr},
+    {embedding_aggregations=},
 )
 """
 
