-
Notifications
You must be signed in to change notification settings - Fork 1
Add function to calculate soap descriptors #74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
9caa400
871fc36
b201c89
7a868f0
5c7ae9e
e5714e3
e279981
2925e30
baac508
89fbe9c
99ad7cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| import argparse | ||
| import sys | ||
| import yaml | ||
|
|
||
|
|
||
| def read_file(path): | ||
| with open(path) as f: | ||
| return yaml.safe_load(f) | ||
|
|
||
|
|
||
| def parse_args(argv=None): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument('--base', dest='base', help='base environment.yml file') | ||
| parser.add_argument('--add', dest='add', help='addon environment.yml file') | ||
| return parser.parse_args(argv) | ||
|
|
||
|
|
||
| def merge_dependencies(env_base, env_add): | ||
| base_dict = {f.split()[0]: f for f in env_base} | ||
| add_dict = {f.split()[0]: f for f in env_add} | ||
| for k,v in add_dict.items(): | ||
| if k not in base_dict.keys(): | ||
| base_dict[k] = v | ||
| return list(base_dict.values()) | ||
|
|
||
|
|
||
| def merge_channels(env_base, env_add): | ||
| for c in env_add: | ||
| if c not in env_base: | ||
| env_base.append(c) | ||
| return env_base | ||
|
|
||
|
|
||
| def merge_env(env_base, env_add): | ||
| return { | ||
| "channels": merge_channels( | ||
| env_base=env_base['channels'], | ||
| env_add=env_add['channels'] | ||
| ), | ||
| 'dependencies': merge_dependencies( | ||
| env_base=env_base['dependencies'], | ||
| env_add=env_add['dependencies'] | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| arguments = parse_args(argv=None) | ||
| yaml.dump( | ||
| merge_env( | ||
| env_base=read_file(arguments.base), | ||
| env_add=read_file(arguments.add) | ||
| ), | ||
| sys.stdout, | ||
| indent=2, | ||
| default_flow_style=False | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| channels: | ||
| - conda-forge | ||
| dependencies: | ||
| - dscribe =2.1.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| def calculate_soap_descriptor_per_atom( | ||
| structure, | ||
| r_cut=None, | ||
| n_max=None, | ||
| l_max=None, | ||
| sigma=1.0, | ||
| rbf="gto", | ||
| weighting=None, | ||
| average="off", | ||
| compression={"mode": "off", "species_weighting": None}, | ||
| species=None, | ||
| periodic=True, | ||
| sparse=False, | ||
| dtype="float64", | ||
| centers=None, | ||
| n_jobs=1, | ||
| only_physical_cores=False, | ||
| verbose=False, | ||
| ): | ||
| from dscribe.descriptors import SOAP | ||
|
|
||
| if species is None: | ||
| species = list(set(structure.get_chemical_symbols())) | ||
| periodic_soap = SOAP( | ||
| r_cut=r_cut, | ||
| n_max=n_max, | ||
| l_max=l_max, | ||
| sigma=sigma, | ||
| rbf=rbf, | ||
| weighting=weighting, | ||
| average=average, | ||
| compression=compression, | ||
| species=species, | ||
| periodic=periodic, | ||
| sparse=sparse, | ||
| dtype=dtype, | ||
| ) | ||
| return periodic_soap.create( | ||
| system=structure, | ||
| centers=centers, | ||
| n_jobs=n_jobs, | ||
| only_physical_cores=only_physical_cores, | ||
| verbose=verbose, | ||
| ) | ||
|
Comment on lines
+38
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer More explicit and clearer to someone who is trying to figure out what the function does.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would actually integrate better with the workflow stuff too. There, we scrape the output channel label from whatever's after the return value -- giving it a nice variable like that keeps things extra tidy. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # coding: utf-8 | ||
| # Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department | ||
| # Distributed under the terms of "New BSD License", see the LICENSE file. | ||
|
|
||
| import unittest | ||
| from ase.build import bulk | ||
| import numpy as np | ||
| import structuretoolkit as stk | ||
|
|
||
| try: | ||
| import dscribe | ||
|
|
||
| skip_dscribe_test = False | ||
| except ImportError: | ||
| skip_dscribe_test = True | ||
|
|
||
|
|
||
| @unittest.skipIf( | ||
| skip_dscribe_test, "dscribe is not installed, so the dscribe tests are skipped." | ||
| ) | ||
| class Testdscribe(unittest.TestCase): | ||
| def test_calc_soap_descriptor_per_atom(self): | ||
| structure = bulk('Cu', 'fcc', a=3.6, cubic=True) | ||
| soap = stk.analyse.calculate_soap_descriptor_per_atom(structure=structure, r_cut=6.0, n_max=8, l_max=6) | ||
| self.assertEqual(soap.shape, (4, 252)) | ||
| self.assertTrue(np.isclose(soap.sum(), 39450.03009, atol=1.e-5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd just go with
stk.analyse.soap_descriptor_per_atomas I'm not convincedcalculate_really adds any information