@janosh janosh commented May 12, 2022

This PR adds a new variant of the Wren model called Wrenformer, a rewrite that replaces the custom self-attention with PyTorch's built-in TransformerEncoder. It also prepares Matbench submissions for Roost and Wren (structure tasks only for Wren).

Rewriting Wren as a transformer encoder was necessary for it to run robustly across Matbench tasks: the original Wren would run out of memory during training and inference on materials with large numbers of Wyckoff positions (<= 16 was a safe cutoff). This happened even on A100 GPUs with 80 GB of memory.
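
For intuition on why large numbers of Wyckoff positions are costly, here is a rough sketch. It assumes memory for pairwise attention scores grows with the square of the padded sequence length, the same quadratic assumption the workload estimate further down relies on. The PR does not state that this is the cause of the original out-of-memory errors. The function name, batch size, head count and float32 precision are illustrative assumptions, not values taken from aviary.

```python
def attention_scores_bytes(
    batch_size: int, n_heads: int, seq_len: int, bytes_per_float: int = 4
) -> int:
    """Bytes needed to hold one layer's attention score matrices for a padded batch.

    Assumes one (seq_len x seq_len) score matrix per head and per sample.
    """
    return batch_size * n_heads * seq_len**2 * bytes_per_float


# illustrative values only: batch of 128 samples, 4 attention heads, float32
for n_wyckoff in (16, 64, 256):
    mib = attention_scores_bytes(128, 4, n_wyckoff) / 2**20
    print(f"n_wyckoff={n_wyckoff:>3}: {mib:.1f} MiB of attention scores per layer")
```

Under this assumption, doubling the number of Wyckoff positions quadruples the memory for attention scores, before counting activations kept for the backward pass.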

Initial performance testing (training for 100 epochs and with only 3 transformer layers) suggests Wrenformer slightly beats Wren across all Matbench tasks I recorded for both.

[Image: 2022-04-29-matbench-scaled-error-heatmap]

More sets of hyperparameters

[Image: 2022-06-06-matbench-scaled-errors-heatmap]

Speed difference between Wren and Wrenformer

According to @CompRhys, Wren could run 500 epochs in 5.5 h on a P100 when training on 120k samples of MP data (similar to the matbench_mp_e_form dataset with 132k samples). Wrenformer only managed 207 epochs in 4 h on the more powerful A100 when training on matbench_mp_e_form. However, to avoid out-of-memory issues, Rhys constrained Wren to only run on systems with <= 16 Wyckoff positions. The code below shows that this lightens the workload by a factor of about 7.5, which likely explains the apparent slowdown in Wrenformer.

import pandas as pd
from aviary.wren.utils import count_wyks
from examples.mat_bench import DATA_PATHS

df = pd.read_json(DATA_PATHS["matbench_mp_e_form"])

df["n_wyckoff"] = df.wyckoff.map(count_wyks)


sum_wyckoffs_sqr = (df.n_wyckoff**2).sum()
sum_wyckoffs_lte_16_sqr = (df.query("n_wyckoff <= 16").n_wyckoff ** 2).sum()
print(f"{sum_wyckoffs_sqr=}")
print(f"{sum_wyckoffs_lte_16_sqr=}")
print(f"{sum_wyckoffs_sqr/sum_wyckoffs_lte_16_sqr=:.3}")
# prints 7.45, so Wrenformer has to do 7.45x more work, which likely explains the
# roughly 2x slowdown despite the more powerful GPU (Wrenformer on an Nvidia A100 vs Wren on a P100)
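
The "about 2x" slowdown follows from the epoch counts quoted above. A quick check, which ignores the difference in dataset size (120k vs 132k samples) and in hardware; the variable names are illustrative:

```python
# throughput in epochs per hour, using the figures quoted in this description
wren_epochs_per_hour = 500 / 5.5  # Wren on a P100, systems with <= 16 Wyckoff positions
wrenformer_epochs_per_hour = 207 / 4  # Wrenformer on an A100, no cutoff

slowdown = wren_epochs_per_hour / wrenformer_epochs_per_hour

print(f"{wren_epochs_per_hour=:.1f}")
print(f"{wrenformer_epochs_per_hour=:.1f}")
print(f"{slowdown=:.2f}")
```

The ratio comes out a little under 2, well below the 7.45x increase in estimated work.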

janosh added 28 commits April 27, 2022 14:10
rename plot_scaled_errors() to scale_errors()
…l matbench tasks

modify examples/mat_bench/slurm_submit.py to run wrenformer
…/wren/data.py

rename cry_ids -> material_ids
* add class InMemoryDataLoader in new module aviary/data.py

refactor run_matbench_task() to work with it

* remove device kwarg from BaseModelClass, load tensors onto GPU externally

doing device IO inside the epoch loop can lead to a significant slowdown and means doing the same work at every epoch instead of once
also remove WyckoffData class from wrenformer/data.py and improve slurm submit script header formatting

* rewrite print_walltime decorator to also work as context manager

ensure model checkpoints and tensorboard logs are always saved relative to project root by prefixing paths with ROOT

* refactor run_matbench_task() to do single fold so each fold can be a separate slurm job

* mv examples/mat_bench/run_{wrenformer=>matbench}.py
…of just JSON

simply write {dataset: {fold: preds}} dict as compressed JSON to disk
…ion tasks

fix typo in InMemoryDataLoader: default shuffle=True->False
…oader 1024->128

also fix key error in bench_dict[dataset_name][fold]
use it to run roostformer in run_matbench_task() if model name contains roost
more efficient than reloading all model preds and computing afterwards
…merging results from separate slurm jobs

run_matbench_task() drop arg benchmark_path: str replaced by timestamp: str
…y dependency (was imported only for softmax)
d_model can now be specified separately depending on dataset size
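
One of the commits above rewrites the print_walltime decorator to also work as a context manager. A minimal sketch of how such a dual-purpose helper could look, using the standard library's contextlib.ContextDecorator. This is a hypothetical illustration, not aviary's actual implementation: the constructor argument, the `elapsed` attribute and the printed message are all assumptions.

```python
import time
from contextlib import ContextDecorator


class print_walltime(ContextDecorator):
    """Hypothetical sketch: time a block or function call and print the wall time."""

    def __init__(self, label: str = "") -> None:
        self.label = label  # assumed argument, used only in the printed message

    def __enter__(self) -> "print_walltime":
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc_info: object) -> bool:
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.label} took {self.elapsed:.2f} sec")
        return False  # do not swallow exceptions raised inside the block


# used as a decorator
@print_walltime("decorated call")
def slow_square(x: int) -> int:
    time.sleep(0.01)
    return x * x


result = slow_square(3)

# used as a context manager
with print_walltime("with block") as timer:
    time.sleep(0.01)
```

Inheriting from ContextDecorator means the timing logic is written once in `__enter__` and `__exit__` and serves both usages.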
@janosh janosh added the enhancement New feature or request label May 12, 2022
janosh added 4 commits June 16, 2022 20:44
also rename some poorly named variables:
- element_weights -> wyckoff_site_multiplicities in aviary/wren/data.py
- aflow -> aflow_label_with_chemsys in aviary/wren/utils.py
remove global numpy random seed in aviary/data.py
arises due to torch.tensor(float("nan")) defaulting to CPU
@janosh janosh force-pushed the wrenformer branch 2 times, most recently from ce30cb8 to c0b2ba7 June 16, 2022 20:56
@janosh janosh merged commit 1b121a0 into main Jun 16, 2022
@janosh janosh deleted the wrenformer branch June 16, 2022 20:57
@janosh janosh mentioned this pull request Jun 18, 2022
CompRhys pushed a commit that referenced this pull request Jun 30, 2022
CompRhys pushed a commit that referenced this pull request Jun 30, 2022
CompRhys pushed a commit that referenced this pull request Jul 1, 2022
CompRhys pushed a commit that referenced this pull request Jul 1, 2022