Skip to content

Commit 1b121a0

Browse files
authored
Merge pull request #44 from CompRhys/wrenformer
Wrenformer
2 parents 59b8017 + c0b2ba7 commit 1b121a0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+2875
-655
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@ jobs:
1818
- name: Set up Python
1919
uses: actions/setup-python@v3
2020
with:
21-
python-version: 3.7
21+
python-version: 3.8
2222
cache: pip
2323
cache-dependency-path: setup.py
2424

2525
- name: Install dependencies
2626
run: |
27-
pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
28-
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
27+
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
28+
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cpu.html
2929
pip install .[test]
30-
cat aviary.egg-info/SOURCES.txt
3130
3231
- name: Run Tests
3332
run: python -m pytest --capture=no --cov aviary

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,10 @@ datasets/
2929
pds/
3030
manuscript/
3131
voro-thesis/
32+
33+
# MatBench run artifacts like model preds, checkpoints, metrics and slurm job logs
34+
examples/mat_bench/model_preds
35+
examples/mat_bench/model_scores
36+
examples/mat_bench/checkpoints
37+
job-logs*
38+
wandb

.pre-commit-config.yaml

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ repos:
4040
hooks:
4141
- id: codespell
4242
exclude_types: [json]
43-
args: [--ignore-words-list, 'hist,ba']
4443

4544
- repo: https://github.com/psf/black
4645
rev: 22.3.0
@@ -51,7 +50,7 @@ repos:
5150
rev: v0.942
5251
hooks:
5352
- id: mypy
54-
exclude: (tests|examples)
53+
exclude: (tests|examples)/
5554

5655
- repo: https://github.com/myint/autoflake
5756
rev: v1.4
@@ -68,11 +67,9 @@ repos:
6867
rev: 6.1.1
6968
hooks:
7069
- id: pydocstyle
71-
# D100: Missing docstring in public module
72-
# D104: Missing docstring in public package
73-
# D105: Missing docstring in magic method
74-
# D107: Missing docstring in __init__
75-
# D205: 1 blank line required between summary line and description
76-
# D415: First line should end with ., ? or !
77-
args: [--convention=google, '--add-ignore=D100,D104,D105,D107,D205,D415']
78-
exclude: (tests|examples)
70+
exclude: (tests|examples)/
71+
72+
- repo: https://github.com/janosh/format-ipy-cells
73+
rev: v0.1.10
74+
hooks:
75+
- id: format-ipy-cells

aviary/cgcnn/data.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def __getitem__(self, idx: int):
168168
- list[str | int]: identifiers like material_id, composition
169169
"""
170170
# NOTE sites must be given in fractional coordinates
171-
df_idx = self.df.iloc[idx]
172-
crystal = df_idx["Structure_obj"]
173-
cry_ids = df_idx[self.identifiers]
171+
row = self.df.iloc[idx]
172+
crystal = row["Structure_obj"]
173+
material_ids = row[self.identifiers]
174174

175175
# atom features for disordered sites
176176
site_atoms = [atom.species.as_dict() for atom in crystal]
@@ -187,11 +187,13 @@ def __getitem__(self, idx: int):
187187
self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal)
188188

189189
if not len(self_idx):
190-
raise AssertionError(f"All atoms in {cry_ids} are isolated")
190+
raise AssertionError(f"All atoms in {material_ids} are isolated")
191191
if not len(nbr_idx):
192-
raise AssertionError(f"This should not be triggered but was for {cry_ids}")
192+
raise AssertionError(
193+
f"This should not be triggered but was for {material_ids}"
194+
)
193195
if set(self_idx) != set(range(crystal.num_sites)):
194-
raise AssertionError(f"At least one atom in {cry_ids} is isolated")
196+
raise AssertionError(f"At least one atom in {material_ids} is isolated")
195197

196198
nbr_dist = self.gdf.expand(nbr_dist)
197199

@@ -203,14 +205,14 @@ def __getitem__(self, idx: int):
203205
targets: list[Tensor | LongTensor] = []
204206
for target, task_type in self.task_dict.items():
205207
if task_type == "regression":
206-
targets.append(Tensor([df_idx[target]]))
208+
targets.append(Tensor([row[target]]))
207209
elif task_type == "classification":
208-
targets.append(LongTensor([df_idx[target]]))
210+
targets.append(LongTensor([row[target]]))
209211

210212
return (
211213
(atom_fea_t, nbr_dist_t, self_idx_t, nbr_idx_t),
212214
targets,
213-
*cry_ids,
215+
*material_ids,
214216
)
215217

216218

aviary/cgcnn/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch_scatter import scatter_add, scatter_mean
88

99
from aviary.core import BaseModelClass
10-
from aviary.segments import SimpleNetwork
10+
from aviary.networks import SimpleNetwork
1111

1212

1313
class CrystalGraphConvNet(BaseModelClass):
@@ -36,7 +36,10 @@ def __init__(
3636
"""Initialize CrystalGraphConvNet.
3737
3838
Args:
39-
robust (bool): Whether to estimate standard deviation for use in a robust loss function
39+
robust (bool): If True, the number of model outputs is doubled. 2nd output for each
40+
target will be an estimate for the aleatoric uncertainty (uncertainty inherent to
41+
the sample) which can be used with a robust loss function to attenuate the weighting
42+
of uncertain samples.
4043
n_targets (list[int]): Number of targets to train on
4144
elem_emb_len (int): Number of atom features in the input.
4245
nbr_fea_len (int): Number of bond features.

0 commit comments

Comments
 (0)