Skip to content

Commit 8b84866

Browse files
committed
add module tests/test_package.py with test_egg_sources()
checks that we're correctly packaging data files
1 parent 40757c1 commit 8b84866

File tree

6 files changed

+51
-28
lines changed

6 files changed

+51
-28
lines changed

aviary/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from os.path import abspath, dirname
2+
3+
# Absolute path of the directory that contains this __init__.py (the aviary package directory).
PKG_DIR = dirname(abspath(__file__))
4+
# Parent directory of PKG_DIR; presumably the repository root in a source checkout — verify for installed packages.
ROOT = dirname(PKG_DIR)

aviary/cgcnn/data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import ast
44
import functools
55
import json
6+
import os
67
from itertools import groupby
7-
from os.path import abspath, dirname, exists, join
88
from typing import Any, Sequence
99

1010
import numpy as np
@@ -14,6 +14,8 @@
1414
from torch import LongTensor, Tensor
1515
from torch.utils.data import Dataset
1616

17+
from aviary import PKG_DIR
18+
1719

1820
class CrystalGraphData(Dataset):
1921
"""Dataset class for the CGCNN structure model."""
@@ -60,12 +62,9 @@ def __init__(
6062
self.max_num_nbr = max_num_nbr
6163

6264
if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
63-
elem_emb = join(
64-
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
65-
)
66-
else:
67-
if not exists(elem_emb):
68-
raise AssertionError(f"{elem_emb} does not exist!")
65+
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
66+
elif not os.path.exists(elem_emb):
67+
raise AssertionError(f"{elem_emb} does not exist!")
6968

7069
with open(elem_emb) as f:
7170
self.elem_features = json.load(f)

aviary/roost/data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import functools
44
import json
5-
from os.path import abspath, dirname, exists, join
5+
import os
66
from typing import Any, Sequence
77

88
import numpy as np
@@ -12,6 +12,8 @@
1212
from torch import LongTensor, Tensor
1313
from torch.utils.data import Dataset
1414

15+
from aviary import PKG_DIR
16+
1517

1618
class CompositionData(Dataset):
1719
"""Dataset class for the Roost composition model."""
@@ -46,12 +48,9 @@ def __init__(
4648
self.df = df
4749

4850
if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
49-
elem_emb = join(
50-
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
51-
)
52-
else:
53-
if not exists(elem_emb):
54-
raise AssertionError(f"{elem_emb} does not exist!")
51+
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
52+
elif not os.path.exists(elem_emb):
53+
raise AssertionError(f"{elem_emb} does not exist!")
5554

5655
with open(elem_emb) as f:
5756
self.elem_features = json.load(f)

aviary/wren/data.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import functools
44
import json
5+
import os
56
import re
67
from itertools import groupby
7-
from os.path import abspath, dirname, exists, join
88
from typing import Any, Sequence
99

1010
import numpy as np
@@ -13,6 +13,7 @@
1313
from torch import LongTensor, Tensor
1414
from torch.utils.data import Dataset
1515

16+
from aviary import PKG_DIR
1617
from aviary.wren.utils import mult_dict, relab_dict
1718

1819

@@ -51,25 +52,19 @@ def __init__(
5152
self.df = df
5253

5354
if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
54-
elem_emb = join(
55-
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
56-
)
57-
else:
58-
if not exists(elem_emb):
59-
raise AssertionError(f"{elem_emb} does not exist!")
55+
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
56+
elif not os.path.exists(elem_emb):
57+
raise AssertionError(f"{elem_emb} does not exist!")
6058

6159
with open(elem_emb) as f:
6260
self.elem_features = json.load(f)
6361

6462
self.elem_emb_len = len(list(self.elem_features.values())[0])
6563

6664
if sym_emb in ["bra-alg-off", "spg-alg-off"]:
67-
sym_emb = join(
68-
dirname(abspath(__file__)), f"../embeddings/wyckoff/{sym_emb}.json"
69-
)
70-
else:
71-
if not exists(sym_emb):
72-
raise AssertionError(f"{sym_emb} does not exist!")
65+
sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json"
66+
elif not os.path.exists(sym_emb):
67+
raise AssertionError(f"{sym_emb} does not exist!")
7368

7469
with open(sym_emb) as f:
7570
self.sym_features = json.load(f)

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
author="Rhys Goodall",
77
author_email="[email protected]",
88
url="https://github.com/CompRhys/aviary",
9-
description="A Collection of Machine Learning Models for Materials Discovery",
9+
description="A collection of machine learning models for materials discovery",
1010
long_description=open("README.md").read(),
1111
long_description_content_type="text/markdown",
1212
packages=find_packages(include=["aviary*"]),
1313
classifiers=[
1414
"Programming Language :: Python :: 3.7",
1515
"Programming Language :: Python :: 3.8",
1616
"Programming Language :: Python :: 3.9",
17+
"Programming Language :: Python :: 3.10",
18+
"Programming Language :: Python :: 3.11",
1719
"License :: OSI Approved :: MIT License",
1820
"Operating System :: OS Independent",
1921
],

tests/test_package.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
from glob import glob
3+
4+
import pytest
5+
6+
from aviary import ROOT
7+
8+
package_sources_path = f"{ROOT}/aviary.egg-info/SOURCES.txt"
9+
10+
11+
@pytest.mark.skipif(
    not os.path.exists(package_sources_path),
    reason="No aviary.egg-info/SOURCES.txt file, run pip install . to create it",
)
def test_egg_sources():
    """Check that every JSON file under ``aviary/`` is listed in SOURCES.txt.

    The test is skipped when ``aviary.egg-info/SOURCES.txt`` does not exist.
    Each JSON file is looked up by its path relative to the ``aviary``
    directory, as a substring of the SOURCES.txt contents.

    Raises:
        AssertionError: If no JSON files are found at all, or if any JSON
            file's relative path does not appear in SOURCES.txt.
    """
    # Local import: only the glob() function is imported at module level.
    from glob import escape

    with open(package_sources_path) as file:
        sources = file.read()

    pkg_dir = f"{ROOT}/aviary"

    # glob()'s ``root_dir`` keyword only exists on Python >= 3.10, while the
    # package also advertises 3.7-3.9. Globbing an absolute pattern and then
    # relativizing yields the same relative paths on every supported version.
    # escape() keeps special characters in the checkout path from being
    # interpreted as glob wildcards.
    json_files_under_aviary = [
        # NOTE(review): presumably SOURCES.txt uses "/" separators on every
        # platform — confirm against setuptools. On POSIX this is a no-op.
        os.path.relpath(path, pkg_dir).replace(os.sep, "/")
        for path in glob(f"{escape(pkg_dir)}/**/*.json", recursive=True)
    ]

    # Guard against a vacuous pass: an empty match means the glob is broken
    # or the data files are missing, not that packaging is correct.
    assert json_files_under_aviary, f"No JSON files found under {pkg_dir}"

    for json_file in json_files_under_aviary:
        assert json_file in sources, f"{json_file} not found in SOURCES.txt"

0 commit comments

Comments
 (0)