Skip to content

Commit 59b8017

Browse files
authored
Add test module for package health (#46)
* add module tests/test_package.py with test_egg_sources() checks that we're correctly packaging data files * refactor test_egg_sources() to not use glob(root_dir=str) was only added in py3.10 * fix including JSON data files in the package (for real this time)
1 parent 40757c1 commit 59b8017

File tree

8 files changed

+62
-29
lines changed

8 files changed

+62
-29
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
2828
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
2929
pip install .[test]
30+
cat aviary.egg-info/SOURCES.txt
3031
3132
- name: Run Tests
3233
run: python -m pytest --capture=no --cov aviary

aviary/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from os.path import abspath, dirname
2+
3+
PKG_DIR = dirname(abspath(__file__))
4+
ROOT = dirname(PKG_DIR)

aviary/cgcnn/data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import ast
44
import functools
55
import json
6+
import os
67
from itertools import groupby
7-
from os.path import abspath, dirname, exists, join
88
from typing import Any, Sequence
99

1010
import numpy as np
@@ -14,6 +14,8 @@
1414
from torch import LongTensor, Tensor
1515
from torch.utils.data import Dataset
1616

17+
from aviary import PKG_DIR
18+
1719

1820
class CrystalGraphData(Dataset):
1921
"""Dataset class for the CGCNN structure model."""
@@ -60,12 +62,9 @@ def __init__(
6062
self.max_num_nbr = max_num_nbr
6163

6264
if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
63-
elem_emb = join(
64-
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
65-
)
66-
else:
67-
if not exists(elem_emb):
68-
raise AssertionError(f"{elem_emb} does not exist!")
65+
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
66+
elif not os.path.exists(elem_emb):
67+
raise AssertionError(f"{elem_emb} does not exist!")
6968

7069
with open(elem_emb) as f:
7170
self.elem_features = json.load(f)

aviary/roost/data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import functools
44
import json
5-
from os.path import abspath, dirname, exists, join
5+
import os
66
from typing import Any, Sequence
77

88
import numpy as np
@@ -12,6 +12,8 @@
1212
from torch import LongTensor, Tensor
1313
from torch.utils.data import Dataset
1414

15+
from aviary import PKG_DIR
16+
1517

1618
class CompositionData(Dataset):
1719
"""Dataset class for the Roost composition model."""
@@ -46,12 +48,9 @@ def __init__(
4648
self.df = df
4749

4850
if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
49-
elem_emb = join(
50-
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
51-
)
52-
else:
53-
if not exists(elem_emb):
54-
raise AssertionError(f"{elem_emb} does not exist!")
51+
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
52+
elif not os.path.exists(elem_emb):
53+
raise AssertionError(f"{elem_emb} does not exist!")
5554

5655
with open(elem_emb) as f:
5756
self.elem_features = json.load(f)

aviary/wren/data.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import functools
44
import json
5+
import os
56
import re
67
from itertools import groupby
7-
from os.path import abspath, dirname, exists, join
88
from typing import Any, Sequence
99

1010
import numpy as np
@@ -13,6 +13,7 @@
1313
from torch import LongTensor, Tensor
1414
from torch.utils.data import Dataset
1515

16+
from aviary import PKG_DIR
1617
from aviary.wren.utils import mult_dict, relab_dict
1718

1819

@@ -51,25 +52,19 @@ def __init__(
5152
self.df = df
5253

5354
if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
54-
elem_emb = join(
55-
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
56-
)
57-
else:
58-
if not exists(elem_emb):
59-
raise AssertionError(f"{elem_emb} does not exist!")
55+
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
56+
elif not os.path.exists(elem_emb):
57+
raise AssertionError(f"{elem_emb} does not exist!")
6058

6159
with open(elem_emb) as f:
6260
self.elem_features = json.load(f)
6361

6462
self.elem_emb_len = len(list(self.elem_features.values())[0])
6563

6664
if sym_emb in ["bra-alg-off", "spg-alg-off"]:
67-
sym_emb = join(
68-
dirname(abspath(__file__)), f"../embeddings/wyckoff/{sym_emb}.json"
69-
)
70-
else:
71-
if not exists(sym_emb):
72-
raise AssertionError(f"{sym_emb} does not exist!")
65+
sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json"
66+
elif not os.path.exists(sym_emb):
67+
raise AssertionError(f"{sym_emb} does not exist!")
7368

7469
with open(sym_emb) as f:
7570
self.sym_features = json.load(f)

setup.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
author="Rhys Goodall",
77
author_email="[email protected]",
88
url="https://github.com/CompRhys/aviary",
9-
description="A Collection of Machine Learning Models for Materials Discovery",
9+
description="A collection of machine learning models for materials discovery",
1010
long_description=open("README.md").read(),
1111
long_description_content_type="text/markdown",
1212
packages=find_packages(include=["aviary*"]),
1313
classifiers=[
1414
"Programming Language :: Python :: 3.7",
1515
"Programming Language :: Python :: 3.8",
1616
"Programming Language :: Python :: 3.9",
17+
"Programming Language :: Python :: 3.10",
18+
"Programming Language :: Python :: 3.11",
1719
"License :: OSI Approved :: MIT License",
1820
"Operating System :: OS Independent",
1921
],
@@ -23,7 +25,9 @@
2325
"Wyckoff positions",
2426
"Crystal Structure Prediction",
2527
],
26-
package_data={"": ["**/*.json"]}, # include all JSON files in the package
28+
# if any package at most 2 levels under the aviary namespace contains *.json files,
29+
# include them in the package
30+
package_data={"aviary": ["**/*.json", "**/**/*.json"]},
2731
python_requires=">=3.7",
2832
install_requires=[
2933
"scipy",

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from aviary.cgcnn.utils import get_cgcnn_input
88
from aviary.wren.utils import get_aflow_label_spglib
99

10+
__author__ = "Janosh Riebesell"
11+
__date__ = "2022-04-09"
12+
1013
torch.manual_seed(0) # ensure reproducible results (applies to all tests)
1114

1215

tests/test_package.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
from glob import glob
3+
4+
import pytest
5+
6+
from aviary import ROOT
7+
8+
package_sources_path = f"{ROOT}/aviary.egg-info/SOURCES.txt"
9+
10+
11+
__author__ = "Janosh Riebesell"
12+
__date__ = "2022-05-25"
13+
14+
15+
@pytest.mark.skipif(
16+
not os.path.exists(package_sources_path),
17+
reason="No aviary.egg-info/SOURCES.txt file, run pip install . to create it",
18+
)
19+
def test_egg_sources():
20+
"""Check we're correctly packaging all JSON files under aviary/ to prevent issues
21+
like https://github.com/CompRhys/aviary/pull/45.
22+
"""
23+
with open(package_sources_path) as file:
24+
sources = file.read()
25+
26+
for filepath in glob(f"{ROOT}/aviary/**/*.json", recursive=True):
27+
rel_path = filepath.split(f"{ROOT}/aviary/")[1]
28+
assert rel_path in sources

0 commit comments

Comments
 (0)