|
2 | 2 |
|
3 | 3 | import functools |
4 | 4 | import json |
| 5 | +import os |
5 | 6 | import re |
6 | 7 | from itertools import groupby |
7 | | -from os.path import abspath, dirname, exists, join |
8 | 8 | from typing import Any, Sequence |
9 | 9 |
|
10 | 10 | import numpy as np |
|
13 | 13 | from torch import LongTensor, Tensor |
14 | 14 | from torch.utils.data import Dataset |
15 | 15 |
|
| 16 | +from aviary import PKG_DIR |
16 | 17 | from aviary.wren.utils import mult_dict, relab_dict |
17 | 18 |
|
18 | 19 |
|
@@ -51,25 +52,19 @@ def __init__( |
51 | 52 | self.df = df |
52 | 53 |
|
53 | 54 | if elem_emb in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: |
54 | | - elem_emb = join( |
55 | | - dirname(abspath(__file__)), f"../embeddings/element/{elem_emb}.json" |
56 | | - ) |
57 | | - else: |
58 | | - if not exists(elem_emb): |
59 | | - raise AssertionError(f"{elem_emb} does not exist!") |
| 55 | + elem_emb = f"{PKG_DIR}/embeddings/element/{elem_emb}.json" |
| 56 | + elif not os.path.exists(elem_emb): |
| 57 | + raise AssertionError(f"{elem_emb} does not exist!") |
60 | 58 |
|
61 | 59 | with open(elem_emb) as f: |
62 | 60 | self.elem_features = json.load(f) |
63 | 61 |
|
64 | 62 | self.elem_emb_len = len(list(self.elem_features.values())[0]) |
65 | 63 |
|
66 | 64 | if sym_emb in ["bra-alg-off", "spg-alg-off"]: |
67 | | - sym_emb = join( |
68 | | - dirname(abspath(__file__)), f"../embeddings/wyckoff/{sym_emb}.json" |
69 | | - ) |
70 | | - else: |
71 | | - if not exists(sym_emb): |
72 | | - raise AssertionError(f"{sym_emb} does not exist!") |
| 65 | + sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json" |
| 66 | + elif not os.path.exists(sym_emb): |
| 67 | + raise AssertionError(f"{sym_emb} does not exist!") |
73 | 68 |
|
74 | 69 | with open(sym_emb) as f: |
75 | 70 | self.sym_features = json.load(f) |
|
0 commit comments