Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
pip install .[test]
cat aviary.egg-info/SOURCES.txt

- name: Run Tests
run: python -m pytest --capture=no --cov aviary
4 changes: 4 additions & 0 deletions aviary/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from os.path import abspath, dirname

PKG_DIR = dirname(abspath(__file__))
ROOT = dirname(PKG_DIR)
13 changes: 6 additions & 7 deletions aviary/cgcnn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import ast
import functools
import json
import os
from itertools import groupby
from os.path import abspath, dirname, exists, join
from typing import Any, Sequence

import numpy as np
Expand All @@ -14,6 +14,8 @@
from torch import LongTensor, Tensor
from torch.utils.data import Dataset

from aviary import PKG_DIR


class CrystalGraphData(Dataset):
"""Dataset class for the CGCNN structure model."""
Expand Down Expand Up @@ -60,12 +62,9 @@ def __init__(
self.max_num_nbr = max_num_nbr

if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
elem_emb = join(
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
)
else:
if not exists(elem_emb):
raise AssertionError(f"{elem_emb} does not exist!")
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
elif not os.path.exists(elem_emb):
raise AssertionError(f"{elem_emb} does not exist!")

with open(elem_emb) as f:
self.elem_features = json.load(f)
Expand Down
13 changes: 6 additions & 7 deletions aviary/roost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import json
from os.path import abspath, dirname, exists, join
import os
from typing import Any, Sequence

import numpy as np
Expand All @@ -12,6 +12,8 @@
from torch import LongTensor, Tensor
from torch.utils.data import Dataset

from aviary import PKG_DIR


class CompositionData(Dataset):
"""Dataset class for the Roost composition model."""
Expand Down Expand Up @@ -46,12 +48,9 @@ def __init__(
self.df = df

if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
elem_emb = join(
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
)
else:
if not exists(elem_emb):
raise AssertionError(f"{elem_emb} does not exist!")
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
elif not os.path.exists(elem_emb):
raise AssertionError(f"{elem_emb} does not exist!")

with open(elem_emb) as f:
self.elem_features = json.load(f)
Expand Down
21 changes: 8 additions & 13 deletions aviary/wren/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import functools
import json
import os
import re
from itertools import groupby
from os.path import abspath, dirname, exists, join
from typing import Any, Sequence

import numpy as np
Expand All @@ -13,6 +13,7 @@
from torch import LongTensor, Tensor
from torch.utils.data import Dataset

from aviary import PKG_DIR
from aviary.wren.utils import mult_dict, relab_dict


Expand Down Expand Up @@ -51,25 +52,19 @@ def __init__(
self.df = df

if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
elem_emb = join(
dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json"
)
else:
if not exists(elem_emb):
raise AssertionError(f"{elem_emb} does not exist!")
elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json"
elif not os.path.exists(elem_emb):
raise AssertionError(f"{elem_emb} does not exist!")

with open(elem_emb) as f:
self.elem_features = json.load(f)

self.elem_emb_len = len(list(self.elem_features.values())[0])

if sym_emb in ["bra-alg-off", "spg-alg-off"]:
sym_emb = join(
dirname(abspath(__file__)), f"../embeddings/wyckoff/{sym_emb}.json"
)
else:
if not exists(sym_emb):
raise AssertionError(f"{sym_emb} does not exist!")
sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json"
elif not os.path.exists(sym_emb):
raise AssertionError(f"{sym_emb} does not exist!")

with open(sym_emb) as f:
self.sym_features = json.load(f)
Expand Down
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
author="Rhys Goodall",
author_email="[email protected]",
url="https://github.com/CompRhys/aviary",
description="A Collection of Machine Learning Models for Materials Discovery",
description="A collection of machine learning models for materials discovery",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
packages=find_packages(include=["aviary*"]),
classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
Expand All @@ -23,7 +25,9 @@
"Wyckoff positions",
"Crystal Structure Prediction",
],
package_data={"": ["**/*.json"]}, # include all JSON files in the package
# if any package at most 2 levels under the aviary namespace contains *.json files,
# include them in the package
package_data={"aviary": ["**/*.json", "**/**/*.json"]},
python_requires=">=3.7",
install_requires=[
"scipy",
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from aviary.cgcnn.utils import get_cgcnn_input
from aviary.wren.utils import get_aflow_label_spglib

__author__ = "Janosh Riebesell"
__date__ = "2022-04-09"

torch.manual_seed(0) # ensure reproducible results (applies to all tests)


Expand Down
28 changes: 28 additions & 0 deletions tests/test_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
from glob import glob

import pytest

from aviary import ROOT

package_sources_path = f"{ROOT}/aviary.egg-info/SOURCES.txt"


__author__ = "Janosh Riebesell"
__date__ = "2022-05-25"


@pytest.mark.skipif(
not os.path.exists(package_sources_path),
reason="No aviary.egg-info/SOURCES.txt file, run pip install . to create it",
)
def test_egg_sources():
"""Check we're correctly packaging all JSON files under aviary/ to prevent issues
like https://github.com/CompRhys/aviary/pull/45.
"""
with open(package_sources_path) as file:
sources = file.read()

for filepath in glob(f"{ROOT}/aviary/**/*.json", recursive=True):
rel_path = filepath.split(f"{ROOT}/aviary/")[1]
assert rel_path in sources