Wrenformer #44
Merged
rename plot_scaled_errors() to scale_errors()
in collaboration with Rokas
…l matbench tasks
modify examples/mat_bench/slurm_submit.py to run wrenformer
…/wren/data.py rename cry_ids -> material_ids
use longer but clearer variable names
* add class InMemoryDataLoader in new module aviary/data.py (see the sketch after this commit)
refactor run_matbench_task() to work with it
* remove device kwarg from BaseModelClass, load tensors onto GPU externally
doing device IO inside the epoch loop can lead to a significant slowdown and means doing the same work at every epoch instead of once
also remove WyckoffData class from wrenformer/data.py and improve slurm submit script header formatting
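A minimal sketch of what an in-memory loader along these lines might look like (the batch_size=128 and shuffle=False defaults follow later commits in this PR; the body is illustrative, not the actual aviary/data.py implementation):

```python
import numpy as np
import torch

# Illustrative sketch: all tensors stay resident in memory (e.g. already on
# the GPU), so each epoch only slices indices instead of re-doing device IO.
class InMemoryDataLoader:
    def __init__(self, tensors: list[torch.Tensor], batch_size: int = 128, shuffle: bool = False) -> None:
        self.tensors = tensors
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        n_samples = len(self.tensors[0])
        indices = np.random.permutation(n_samples) if self.shuffle else np.arange(n_samples)
        for start in range(0, n_samples, self.batch_size):
            batch_idx = indices[start : start + self.batch_size]
            yield tuple(tensor[batch_idx] for tensor in self.tensors)
```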
* rewrite print_walltime decorator to also work as context manager (see the sketch after this commit)
ensure model checkpoints and tensorboard logs are always saved relative to project root by prefixing paths with ROOT
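The commit message only names the behavior; one hedged way to make a timer work as both a decorator and a context manager is contextlib.ContextDecorator (an assumption about the approach, not necessarily aviary's implementation):

```python
import time
from contextlib import ContextDecorator

# subclassing ContextDecorator makes one class usable both ways
class print_walltime(ContextDecorator):
    def __init__(self, desc: str = "") -> None:
        self.desc = desc

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        print(f"{self.desc} took {time.perf_counter() - self.start:.2f} sec")
        return False

@print_walltime("training")  # decorator usage
def train() -> None: ...

with print_walltime("inference"):  # context manager usage
    pass
```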
* refactor run_matbench_task() to do single fold so each fold can be a separate slurm job
* mv examples/mat_bench/run_{wrenformer=>matbench}.py
…of just JSON
simply write {dataset: {fold: preds}} dict as compressed JSON to disk
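For illustration, writing such a dict as gzipped JSON only takes the standard library (the file name and prediction values here are hypothetical):

```python
import gzip
import json

# {dataset: {fold: preds}} with toy values
bench_dict = {"matbench_mp_e_form": {0: [0.12, -0.34, 1.07]}}

with gzip.open("benchmark_preds.json.gz", mode="wt") as file:
    json.dump(bench_dict, file)
```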
…ion tasks
fix typo in InMemoryDataLoader: default shuffle=True->False
…oader 1024->128
also fix key error in bench_dict[dataset_name][fold]
use it to run roostformer in run_matbench_task() if model name contains roost
more efficient than reloading all model preds and computing afterwards
…ipython shell magic cmd
…merging results from separate slurm jobs
run_matbench_task() drop arg benchmark_path: str, replaced by timestamp: str
…mat_bench/utils.py
…y dependency (was imported only for softmax)
d_model can now be specified separately depending on dataset size
for more information, see https://pre-commit.ci
record number of trainable params in wandb config
…lassification test
also adds batch_size kwarg to run_wrenformer()
CompRhys reviewed Jun 16, 2022
also rename some poorly named variables:
- element_weights -> wyckoff_site_multiplicities in aviary/wren/data.py
- aflow -> aflow_label_with_chemsys in aviary/wren/utils.py
remove global numpy random seed in aviary/data.py
arises due to torch.tensor(float("nan")) defaulting to CPU
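For context, a minimal reproduction of the device pitfall described above (the surrounding aviary code is assumed, not shown; the fix is simply to pin the scalar to the same device):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.randn(4, device=device)

# torch.tensor(float("nan")) is created on the CPU by default, so mixing it
# with GPU tensors can raise a device-mismatch RuntimeError
nan = torch.tensor(float("nan"), device=preds.device)  # pin to the same device
masked = torch.where(preds > 0, preds, nan)
```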
force-pushed from ce30cb8 to c0b2ba7
This PR adds a new variant of the Wren model called Wrenformer, a rewrite that drops the custom self-attention in favor of the built-in PyTorch TransformerEncoder. It also prepares Matbench submissions for Roost and Wren (structure tasks only for Wren).

The Wren rewrite as a transformer encoder was necessary to run robustly across Matbench tasks, as the original Wren would run into out-of-memory errors during training and inference on materials with large numbers of Wyckoff positions (< 16 was a safe cutoff). This happened even on A100 GPUs with 80 GB of RAM.
Initial performance testing (training for 100 epochs with only 3 transformer layers) suggests Wrenformer slightly beats Wren across all Matbench tasks I recorded for both.
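The encoder Wrenformer swaps in is PyTorch's stock module. A minimal sketch (num_layers=3 matches the test above; d_model and nhead are placeholder values, not the PR's actual hyperparameters):

```python
import torch.nn as nn

# stock PyTorch encoder in place of Wren's custom self-attention;
# d_model and nhead are illustrative placeholders
encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
```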
More sets of hyperparameters
Speed difference between Wren and Wrenformer
According to @CompRhys, Wren could run 500 epochs in 5.5 h on a P100 when training on 120k samples of MP data (similar to the matbench_mp_e_form dataset with 132k samples). Wrenformer only managed 207 epochs in 4 h on the more powerful A100 when training on matbench_mp_e_form. However, to avoid out-of-memory issues, Rhys constrained Wren to only run on systems with <= 16 Wyckoff positions. The code below shows that this lightens the workload by a factor of about 7.5, likely explaining the apparent slowdown in Wrenformer.
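The original snippet is not preserved in this thread; the sketch below reconstructs the idea under stated assumptions. Self-attention cost grows quadratically with sequence length (here, the number of Wyckoff positions per structure), and the synthetic counts below merely stand in for the real matbench_mp_e_form structures:

```python
import numpy as np

rng = np.random.default_rng(0)
# fake per-structure Wyckoff position counts; the real distribution
# (132k matbench_mp_e_form structures) has a much heavier tail
n_wyckoff = rng.geometric(p=0.2, size=132_000)

# compare summed squared sequence lengths with and without the <= 16 cutoff
full_cost = float((n_wyckoff**2).sum())
capped_cost = float((n_wyckoff[n_wyckoff <= 16] ** 2).sum())

print(f"workload ratio: {full_cost / capped_cost:.2f}")
# on the real dataset this ratio reportedly came out to about 7.5
```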