Merged
105 commits
b95437a
start work on roost + wren matbench submission
janosh Apr 20, 2022
f9d4146
run_matbench_task() fix reading out classification predictions
janosh Apr 25, 2022
a048224
add examples/matbench/{plotting_functions,make_plots}.py
janosh Apr 25, 2022
5d5ee7e
add scaled_error_heatmap() to matbench/plotting_functions.py
janosh Apr 25, 2022
a3d28a5
mv examples/{matbench,mat_bench} to avoid shadowing matbench package …
janosh Apr 27, 2022
f645ab1
initial working version of Wren as a transformer
janosh Apr 27, 2022
1c4ceed
add examples/mat_bench/run_wrenformer.py for running wrenformer on al…
janosh Apr 28, 2022
349e2ef
wrenformer fix node aggregation: exclude padded sequence values
janosh Apr 28, 2022
4c89dd1
fix pytorch error from model and data on different devices
janosh Apr 28, 2022
5496d4a
run_wrenformer.py only create benchmark_dir if not empty string
janosh Apr 30, 2022
6f81d75
drop parse_aflow() from aviary/wrenformer/data.py, import from aviary…
janosh Apr 30, 2022
34472f5
rename new Wren variant to Wrenformer, fix pydocstyle doc string errors
janosh Apr 30, 2022
7a2d9a0
fix run_matbench_task() saving models to wrong hard-coded checkpoint …
janosh May 1, 2022
fccbdcf
Add `InMemoryDataLoader` (#43)
janosh May 3, 2022
df65c76
run_matbench_task() ditch writing MatbenchBenchmark to disk in favor …
janosh May 3, 2022
3b0aa17
fix run_matbench_task() use correct number of outputs for classificat…
janosh May 4, 2022
280a029
fix run_matbench_task() oom error from too large batch size in test_l…
janosh May 4, 2022
c165b04
add get_composition_embedding() in aviary/wrenformer/data.py
janosh May 4, 2022
131bbef
run_matbench_task() add kwarg n_transformer_layers
janosh May 5, 2022
b0b1958
run_matbench_task() save model scores directly to JSON
janosh May 9, 2022
d276a7b
slurm_submit.py use subprocess.run() to submit slurm jobs instead of …
janosh May 9, 2022
d88a0f0
add examples/mat_bench/utils.py with open_json() context manager for …
janosh May 10, 2022
762e9b6
refactor data loading of model errors in examples/mat_bench/make_plot…
janosh May 11, 2022
abd5c23
move print_walltime context manager from aviary/utils.py to examples/…
janosh May 11, 2022
fca4644
add custom np softmax() and one_hot() implementations and remove scip…
janosh May 11, 2022
d7be55f
wrenformer add linear layer to project embedding dim to d_model
janosh May 11, 2022
76625a7
fix BaseModelClass initial epoch to 0 (was 1)
janosh May 11, 2022
36d7807
fix tests after removal of device kwarg from BaseModelClass
janosh May 12, 2022
1daf90e
fix py37 not supporting unenclosed iterable unpacking
janosh May 12, 2022
17483bb
fix CI not supporting future type annotations
janosh May 12, 2022
5763f10
add get_metrics() in aviary/core.py
janosh May 12, 2022
c5f705b
add count_distinct_wyckoff_letters() in aviary/wren/utils.py
janosh May 12, 2022
7d2c322
drop matbench.data_ops.score_array in favor of aviary.utils.get_metrics
janosh May 12, 2022
ccd40fb
rename parse_aflow() => parse_aflow_wyckoff_str()
janosh May 13, 2022
6a7698e
add log_wandb=True kwarg to run_matbench_task()
janosh May 13, 2022
c730d3a
fix Wrenformer not averaging correctly over equivalent Wyckoff labell…
janosh May 14, 2022
457ffc9
refactor test_get_isopointal_proto() to use @pytest.mark.parametrize
janosh May 15, 2022
10f0df7
add test_get_aflow_label_aflow()
janosh May 15, 2022
3da2eb1
fix dict KeyError when trying to merge results from multiple slurm jo…
janosh May 15, 2022
7066510
refactor BaseModelClass.best_val_scores update
janosh May 16, 2022
4d1303e
refactor BaseModelClass.predict() and rename train/val_{generator=>lo…
janosh May 16, 2022
ebaded0
fix wandb init-start-error and fix data shape in plotly scatter plot …
janosh May 16, 2022
6e60e3e
fix get_aflow_label_aflow() element sort order mismatch between aflow…
janosh May 16, 2022
a269c31
delete plotly_identity_scatter(), rename merge_json() -> merge_json_o…
janosh May 16, 2022
9fc93d7
rename criterion_dict -> loss_dict everywhere
janosh May 16, 2022
2a486e1
add annotate_fig() in examples/mat_bench/plotting_functions.py
janosh May 17, 2022
1945268
add stochastic weight averaging to run_matbench_task()
janosh May 17, 2022
d370c1e
add 4-fold wrenformer embedding aggregation
janosh May 17, 2022
1c6a08e
rename Acc->Accuracy in BaseModelClass.{fit,evaluate}
janosh May 18, 2022
8adfa62
add kwarg errors=raise | annotate | ignore to get_aflow_label_aflow()…
janosh May 18, 2022
b012f78
rename n_transformer_layers to trafo_layers
janosh May 18, 2022
bf148ec
add examples/mat_bench/compare_spglib_aflow_wyckoff_labels.ipynb
janosh May 18, 2022
2222cf8
refactor dicts in make plots from 1st->2nd level = dataset->model_nam…
janosh May 18, 2022
2b5770c
convert examples/mat_bench/make_plots.{py=>ipynb} to notebook
janosh May 18, 2022
e001f6c
rename trafo_layers to n_attn_layers, n_attention_heads to n_attn_heads
janosh May 18, 2022
4fe9235
add tests for np_softmax() and np_one_hot() in new tests/test_core.py
janosh May 19, 2022
4778791
fix min/max aggregation in wrenformer not ignoring padded values
janosh May 20, 2022
fa2c772
add sankey plots of spacegroup distros to compare_spglib_aflow_wyckof…
janosh May 22, 2022
6f4236c
refactor get_aflow_label_aflow() to not require pymatgen.io.vasp.Poscar
janosh May 22, 2022
1a34d1e
rm examples/mat_bench/wrenformer.py
janosh May 22, 2022
b7e7910
add 5th embedding aggregation std
janosh May 22, 2022
de2c24f
add tests for masked_mean() and masked_std() in tests/test_core.py
janosh May 22, 2022
a9a44e9
bump pip install torch + torch.scatter to 1.11 in CI
janosh May 23, 2022
68eabf6
move class SimpleNetwork + ResidualNetwork to new module aviary/nn.py
janosh May 23, 2022
a1fd151
more descriptive variable names in MessageLayer.forward()
janosh May 23, 2022
997353f
fix RuntimeError from tensors on different devices in min/max aggrega…
janosh May 23, 2022
e279117
matbench error_heatmap() add 'dense scaled error' col and sort by that
janosh May 23, 2022
bb8cb95
run_matbench_task() add kwarg checkpoint: None | 'local' | 'wandb' = …
janosh May 25, 2022
1f469a8
try wrenformer with mean+std aggregation and no SWA
janosh May 28, 2022
165bb7f
get_aflow_label_spglib() add kwarg errors=raise | annotate | ignore
janosh Jun 6, 2022
74defe8
run_matbench_task() only login to wandb if not already logged in
janosh Jun 6, 2022
29a632f
run_matbench_task() fix wandb.run.summary recording
janosh Jun 6, 2022
9bcbcdf
run_matbench_task() raise better error if trying to run Wrenformer on…
janosh Jun 6, 2022
6b0e373
include dataset sizes in error_heatmap() and plot_leaderboard() axis …
janosh Jun 6, 2022
1c365a5
convert examples/mat_bench/make_plots.ipynb back to .py file due to h…
janosh Jun 6, 2022
e4e52fe
fix TypeError: Descriptors cannot not be created directly.
janosh Jun 6, 2022
4fc7d33
move codespell + pydocstyle commit hook args to setup.cfg
janosh Jun 8, 2022
94ac5db
explain file purpose in examples/mat_bench/readme.md
janosh Jun 8, 2022
c7b1b90
mv aviary/{nn,networks}.py
janosh Jun 9, 2022
43c4b1a
address Rhys PR review: rename loss_dict->loss_name_dict, _loss_dict-…
janosh Jun 9, 2022
9159c69
fix 19 of 22 remaining mypy errors
janosh Jun 9, 2022
e26301a
build embedding paths in {roost,wren,cgcnn}/data.py from aviary.PKG_DIR
janosh Jun 9, 2022
d48452e
move examples/mat_bench/{featurize_matbench,save_matbench_aflow_label…
janosh Jun 9, 2022
bf6f1e9
use longer more descriptive variable names in BaseModelClass.evaluate()
janosh Jun 11, 2022
5c9f21b
print_walltime context manager can now print before and after
janosh Jun 11, 2022
0ed731e
expand robust kwarg doc string
janosh Jun 11, 2022
95f2ad1
fix BaseModelClass.evaluate() preds-targets shape mismatch in regress…
janosh Jun 12, 2022
20b5a8d
run_matbench_task() chunk model output into preds and aleat_std in ro…
janosh Jun 12, 2022
74216f1
fix tests from another preds-targets shape mismatch
janosh Jun 12, 2022
7560d98
ci-fix: don't check complexity
CompRhys Jun 13, 2022
3581914
rename run_matbench_task() to run_wrenformer() and refactor to suppor…
janosh Jun 13, 2022
9533f22
mv examples/mat_bench/utils.py aviary/wrenformer/utils.py
janosh Jun 13, 2022
04b02d7
add package examples/mp_wbm with modules run_wrenformer.py, slurm_sub…
janosh Jun 15, 2022
8f83842
Merge remote-tracking branch 'origin/main' into wrenformer
janosh Jun 15, 2022
3daee8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2022
9daa82a
torch.save checkpoint swa_model if using SWA
janosh Jun 15, 2022
6fa5400
add kwargs learning_rate and warmup_steps to run_wrenformer()
janosh Jun 15, 2022
65ff974
include test_df in run_wrenformer() return values
janosh Jun 15, 2022
f6b32d4
make embedding_aggregations an explicit kwarg to wrenformer
janosh Jun 16, 2022
5a9a651
add module test_wrenformer.py with non-robust regression and robust c…
janosh Jun 16, 2022
735beec
fix py37 CI by not using walrus operator
janosh Jun 16, 2022
d092832
bump CI python version from 3.7 to 3.8
janosh Jun 16, 2022
19352b5
mv aviary/wrenformer/run.py examples/wrenformer.py
janosh Jun 16, 2022
4122b9a
guard against wandb not installed in examples/wrenformer.py
janosh Jun 16, 2022
c0b2ba7
fix masked_mean() error expected x and y to be on same device
janosh Jun 16, 2022
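
Several commits in this list (349e2ef, 4778791, de2c24f) concern aggregating over padded sequences without letting padding values leak into the result. A minimal pure-Python sketch of that idea follows. It is an illustration under assumptions, not the aviary implementation: the real `masked_mean()` and `masked_std()` are not shown in this section and presumably operate on tensors, and both the mask convention (True marks a real, non-padded entry) and the use of the population standard deviation are assumed here.

```python
import math


def masked_mean(values, mask):
    """Mean over entries whose mask is True (assumed: True = real, False = padding)."""
    kept = [v for v, keep in zip(values, mask) if keep]
    return sum(kept) / len(kept)


def masked_std(values, mask):
    """Population standard deviation over the unmasked entries only."""
    mean = masked_mean(values, mask)
    kept = [v for v, keep in zip(values, mask) if keep]
    return math.sqrt(sum((v - mean) ** 2 for v in kept) / len(kept))


# A length-3 sequence padded with zeros to length 5.
values = [1.0, 2.0, 3.0, 0.0, 0.0]
mask = [True, True, True, False, False]

naive_mean = sum(values) / len(values)    # padding is averaged in
correct_mean = masked_mean(values, mask)  # padding is ignored
```

The naive mean counts the two padding zeros as data, which is the kind of error the "not ignoring padded values" fixes describe.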
7 changes: 3 additions & 4 deletions .github/workflows/test.yml
@@ -18,16 +18,15 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v3
         with:
-          python-version: 3.7
+          python-version: 3.8
           cache: pip
           cache-dependency-path: setup.py

       - name: Install dependencies
         run: |
-          pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
-          pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
+          pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cpu.html
           pip install .[test]
-          cat aviary.egg-info/SOURCES.txt

       - name: Run Tests
         run: python -m pytest --capture=no --cov aviary
7 changes: 7 additions & 0 deletions .gitignore
@@ -29,3 +29,10 @@ datasets/
 pds/
 manuscript/
 voro-thesis/
+
+# MatBench run artifacts like model preds, checkpoints, metrics and slurm job logs
+examples/mat_bench/model_preds
+examples/mat_bench/model_scores
+examples/mat_bench/checkpoints
+job-logs*
+wandb
17 changes: 7 additions & 10 deletions .pre-commit-config.yaml
@@ -40,7 +40,6 @@ repos:
     hooks:
       - id: codespell
         exclude_types: [json]
-        args: [--ignore-words-list, 'hist,ba']

   - repo: https://github.com/psf/black
     rev: 22.3.0
@@ -51,7 +50,7 @@ repos:
     rev: v0.942
     hooks:
       - id: mypy
-        exclude: (tests|examples)
+        exclude: (tests|examples)/

   - repo: https://github.com/myint/autoflake
     rev: v1.4
@@ -68,11 +67,9 @@ repos:
     rev: 6.1.1
     hooks:
       - id: pydocstyle
-        # D100: Missing docstring in public module
-        # D104: Missing docstring in public package
-        # D105: Missing docstring in magic method
-        # D107: Missing docstring in __init__
-        # D205: 1 blank line required between summary line and description
-        # D415: First line should end with ., ? or !
-        args: [--convention=google, '--add-ignore=D100,D104,D105,D107,D205,D415']
-        exclude: (tests|examples)
+        exclude: (tests|examples)/
+
+  - repo: https://github.com/janosh/format-ipy-cells
+    rev: v0.1.10
+    hooks:
+      - id: format-ipy-cells
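
The `exclude` change from `(tests|examples)` to `(tests|examples)/` matters because pre-commit treats `exclude` as a Python regular expression searched anywhere in each file path. The sketch below mimics that with `re.search` to show which paths each pattern would skip. The file paths are made up for illustration (in particular `aviary/examples_helper.py` is hypothetical and not part of this PR), and `excluded` is a helper defined here, not a pre-commit function.

```python
import re

OLD_PATTERN = r"(tests|examples)"
NEW_PATTERN = r"(tests|examples)/"

# Hypothetical paths, chosen only to illustrate the difference.
paths = [
    "tests/test_core.py",
    "examples/wrenformer.py",
    "aviary/examples_helper.py",  # hypothetical module name
    "aviary/core.py",
]


def excluded(pattern: str, path: str) -> bool:
    """Return True if a regex search of the path matches the exclude pattern."""
    return re.search(pattern, path) is not None


old_excluded = [p for p in paths if excluded(OLD_PATTERN, p)]
new_excluded = [p for p in paths if excluded(NEW_PATTERN, p)]
```

With the trailing slash, only paths that contain a `tests/` or `examples/` directory component are skipped; a source file that merely has one of those words in its name is checked again.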
20 changes: 11 additions & 9 deletions aviary/cgcnn/data.py
@@ -168,9 +168,9 @@ def __getitem__(self, idx: int):
             - list[str | int]: identifiers like material_id, composition
         """
         # NOTE sites must be given in fractional coordinates
-        df_idx = self.df.iloc[idx]
-        crystal = df_idx["Structure_obj"]
-        cry_ids = df_idx[self.identifiers]
+        row = self.df.iloc[idx]
+        crystal = row["Structure_obj"]
+        material_ids = row[self.identifiers]

         # atom features for disordered sites
         site_atoms = [atom.species.as_dict() for atom in crystal]
@@ -187,11 +187,13 @@ def __getitem__(self, idx: int):
         self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal)

         if not len(self_idx):
-            raise AssertionError(f"All atoms in {cry_ids} are isolated")
+            raise AssertionError(f"All atoms in {material_ids} are isolated")
         if not len(nbr_idx):
-            raise AssertionError(f"This should not be triggered but was for {cry_ids}")
+            raise AssertionError(
+                f"This should not be triggered but was for {material_ids}"
+            )
         if set(self_idx) != set(range(crystal.num_sites)):
-            raise AssertionError(f"At least one atom in {cry_ids} is isolated")
+            raise AssertionError(f"At least one atom in {material_ids} is isolated")

         nbr_dist = self.gdf.expand(nbr_dist)

@@ -203,14 +205,14 @@ def __getitem__(self, idx: int):
         targets: list[Tensor | LongTensor] = []
         for target, task_type in self.task_dict.items():
             if task_type == "regression":
-                targets.append(Tensor([df_idx[target]]))
+                targets.append(Tensor([row[target]]))
             elif task_type == "classification":
-                targets.append(LongTensor([df_idx[target]]))
+                targets.append(LongTensor([row[target]]))

         return (
             (atom_fea_t, nbr_dist_t, self_idx_t, nbr_idx_t),
             targets,
-            *cry_ids,
+            *material_ids,
         )
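
The three guards in `__getitem__` can be read as a standalone check on the neighbor lists. The following is a minimal sketch, assuming `self_idx` and `nbr_idx` are flat integer sequences holding the source and target site of each edge (the output format of `_get_nbr_data` is not shown in this diff). `check_neighbor_lists` and the `demo-*` identifiers are hypothetical names used only here.

```python
def check_neighbor_lists(self_idx, nbr_idx, num_sites, material_id):
    """Raise AssertionError if the neighbor lists imply isolated sites.

    Assumption: self_idx / nbr_idx hold one source / target site index per edge.
    """
    if not len(self_idx):
        raise AssertionError(f"All atoms in {material_id} are isolated")
    if not len(nbr_idx):
        raise AssertionError(
            f"This should not be triggered but was for {material_id}"
        )
    # Every site must appear at least once as an edge source.
    if set(self_idx) != set(range(num_sites)):
        raise AssertionError(f"At least one atom in {material_id} is isolated")


# Three sites, each appearing as an edge source: passes silently.
check_neighbor_lists([0, 0, 1, 2], [1, 2, 0, 0], num_sites=3, material_id="demo-1")

# Site 2 never appears as a source, so it counts as isolated.
try:
    check_neighbor_lists([0, 1], [1, 0], num_sites=3, material_id="demo-2")
    isolated_error = None
except AssertionError as exc:
    isolated_error = str(exc)
```

The rename from `cry_ids` to `material_ids` only affects the identifier interpolated into these messages; the checks themselves are unchanged by the diff.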


7 changes: 5 additions & 2 deletions aviary/cgcnn/model.py
@@ -7,7 +7,7 @@
 from torch_scatter import scatter_add, scatter_mean

 from aviary.core import BaseModelClass
-from aviary.segments import SimpleNetwork
+from aviary.networks import SimpleNetwork


 class CrystalGraphConvNet(BaseModelClass):
@@ -36,7 +36,10 @@ def __init__(
         """Initialize CrystalGraphConvNet.

         Args:
-            robust (bool): Whether to estimate standard deviation for use in a robust loss function
+            robust (bool): If True, the number of model outputs is doubled. 2nd output for each
+                target will be an estimate for the aleatoric uncertainty (uncertainty inherent to
+                the sample) which can be used with a robust loss function to attenuate the weighting
+                of uncertain samples.
             n_targets (list[int]): Number of targets to train on
             elem_emb_len (int): Number of atom features in the input.
             nbr_fea_len (int): Number of bond features.
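
The expanded `robust` docstring describes a second output per target that estimates aleatoric uncertainty and down-weights uncertain samples in the loss. One common way to do that is a Laplace negative log-likelihood, sketched below in plain Python. This is an assumption for illustration: the diff does not show which loss function aviary pairs with `robust=True`, and the convention that the second output is a log standard deviation, as well as the name `robust_l1_loss`, are choices made here.

```python
import math


def robust_l1_loss(preds, log_stds, targets):
    """Mean Laplace negative log-likelihood (up to an additive constant).

    Assumed convention: for each sample the model emits a prediction and a
    log standard deviation. A larger log_std shrinks the error term but is
    itself penalized, so the model cannot claim high uncertainty for free.
    """
    losses = [
        math.sqrt(2.0) * abs(pred - target) * math.exp(-log_std) + log_std
        for pred, log_std, target in zip(preds, log_stds, targets)
    ]
    return sum(losses) / len(losses)


# Same large error of 10, once with low and once with high predicted uncertainty.
confident_loss = robust_l1_loss([10.0], [0.0], [0.0])
uncertain_loss = robust_l1_loss([10.0], [2.0], [0.0])
```

For the same error, the sample with the higher predicted uncertainty contributes less to the loss, which is the attenuation the docstring refers to.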