Skip to content

Commit 36cebdf

Browse files
maarten-icolivhoenen
authored and committed
Refactor iterators used in ids2nc and nc2ids
1 parent 6d95d00 commit 36cebdf

File tree

3 files changed

+108
-79
lines changed

3 files changed

+108
-79
lines changed

imas/backends/netcdf/ids2nc.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,16 @@
11
# This file is part of IMAS-Python.
22
# You should have received the IMAS-Python LICENSE file with this project.
3-
"""NetCDF IO support for IMAS-Python. Requires [netcdf] extra dependencies.
4-
"""
5-
6-
from typing import Iterator, Tuple
3+
"""NetCDF IO support for IMAS-Python. Requires [netcdf] extra dependencies."""
74

85
import netCDF4
96
import numpy
107
from packaging import version
118

129
from imas.backends.netcdf.nc_metadata import NCMetadata
10+
from imas.backends.netcdf.iterators import indexed_tree_iter
1311
from imas.exception import InvalidNetCDFEntry
14-
from imas.ids_base import IDSBase
1512
from imas.ids_data_type import IDSDataType
1613
from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS
17-
from imas.ids_struct_array import IDSStructArray
18-
from imas.ids_structure import IDSStructure
1914
from imas.ids_toplevel import IDSToplevel
2015

2116
default_fillvals = {
@@ -33,26 +28,6 @@
3328
SHAPE_DTYPE = numpy.int32
3429

3530

36-
def nc_tree_iter(
37-
node: IDSStructure, aos_index: Tuple[int, ...] = ()
38-
) -> Iterator[Tuple[Tuple[int, ...], IDSBase]]:
39-
"""Tree iterator that tracks indices of all ancestor array of structures.
40-
41-
Args:
42-
node: IDS node to iterate over
43-
44-
Yields:
45-
(aos_index, node) for all filled nodes.
46-
"""
47-
for child in node.iter_nonempty_():
48-
yield (aos_index, child)
49-
if isinstance(child, IDSStructArray):
50-
for i in range(len(child)):
51-
yield from nc_tree_iter(child[i], aos_index + (i,))
52-
elif isinstance(child, IDSStructure):
53-
yield from nc_tree_iter(child, aos_index)
54-
55-
5631
class IDS2NC:
5732
"""Class responsible for storing an IDS to a NetCDF file."""
5833

@@ -105,7 +80,7 @@ def collect_filled_data(self) -> None:
10580
dimension_size = {}
10681
get_dimensions = self.ncmeta.get_dimensions
10782

108-
for aos_index, node in nc_tree_iter(self.ids):
83+
for aos_index, node in indexed_tree_iter(self.ids):
10984
path = node.metadata.path_string
11085
filled_data[path][aos_index] = node
11186
ndim = node.metadata.ndim

imas/backends/netcdf/iterators.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from typing import Iterator, List, Optional, Tuple
2+
3+
from imas.ids_base import IDSBase
4+
from imas.ids_data_type import IDSDataType
5+
from imas.ids_metadata import IDSMetadata
6+
from imas.ids_struct_array import IDSStructArray
7+
from imas.ids_structure import IDSStructure
8+
from imas.ids_toplevel import IDSToplevel
9+
10+
11+
def _split_on_aos(metadata: IDSMetadata) -> List[str]:
    """Split the path to ``metadata`` at every ancestor array of structures.

    The ``_parent`` chain of ``metadata`` is followed upwards until an ancestor
    is reached whose own parent has ``data_type is None`` (presumably the IDS
    toplevel -- TODO confirm against IDSMetadata). Ancestors that are a
    ``STRUCT_ARRAY`` close the current segment and start a new one; any other
    ancestor is prepended to the current segment with a ``/`` separator.

    The node described by ``metadata`` itself is never split on, even when it is
    an array of structures.

    Args:
        metadata: Metadata of the node whose path should be split.

    Returns:
        Relative path segments, outermost first. Every segment except the last
        ends at an array of structures. For a node ``a/b/c/d`` where only ``b``
        is an array of structures this is ``["a/b", "c/d"]``; a node without
        array-of-structures ancestors gives a single segment.
    """
    segments = []
    tail = metadata.name

    ancestor = metadata._parent
    while ancestor.data_type is not None:
        if ancestor.data_type is IDSDataType.STRUCT_ARRAY:
            # Crossing an AoS boundary: finish the segment below it and start a
            # new segment that ends at this AoS.
            segments.append(tail)
            tail = ancestor.name
        else:
            tail = f"{ancestor.name}/{tail}"
        ancestor = ancestor._parent

    segments.append(tail)
    # Segments were collected innermost-first; callers need outermost-first.
    segments.reverse()
    return segments
26+
27+
28+
IndexedNode = Tuple[Tuple[int, ...], IDSBase]
29+
30+
31+
def indexed_tree_iter(
    ids: IDSToplevel, metadata: Optional[IDSMetadata] = None
) -> Iterator[IndexedNode]:
    """Iterate over IDS nodes together with the indices of their ancestor AoS.

    Args:
        ids: IDS top level element to iterate over
        metadata: Iterate over all nodes inside the IDS at the metadata object.
            If ``None``, all filled items in the IDS are iterated over.

    Yields:
        (aos_indices, node), where ``aos_indices`` holds one index per ancestor
        array of structures of ``node``. Without ``metadata`` only filled nodes
        are yielded. With ``metadata`` the node at that path is yielded for
        every element of its ancestor arrays of structures, also when it is
        empty (see the second example below).

    Example:
        >>> ids = imas.IDSFactory().new("core_profiles")
        >>> ids.profiles_1d.resize(2)
        >>> ids.profiles_1d[0].time = 1.0
        >>> ids.profiles_1d[1].t_i_average = [1.0]
        >>> list(indexed_tree_iter(ids))
        [
            ((), <IDSStructArray (IDS:core_profiles, profiles_1d with 2 items)>),
            ((0,), <IDSFloat0D (IDS:core_profiles, profiles_1d[0]/time, FLT_0D)>),
            ((1,), <IDSNumericArray (IDS:core_profiles, profiles_1d[1]/t_i_average, FLT_1D)>)
        ]
        >>> list(indexed_tree_iter(ids, ids.metadata["profiles_1d/time"]))
        [
            ((0,), <IDSFloat0D (IDS:core_profiles, profiles_1d[0]/time, FLT_0D)>),
            ((1,), <IDSFloat0D (IDS:core_profiles, profiles_1d[1]/time, empty FLT_0D)>)
        ]
    """  # noqa: E501
    if metadata is not None:
        segments = _split_on_aos(metadata)
        if len(segments) > 1:
            # At least one ancestor AoS: descend through each of them.
            yield from _tree_iter(ids, segments, ())
        else:
            # No ancestor AoS, so there is exactly one node and no indices.
            yield (), ids[segments[0]]
        return

    # No metadata given: walk every filled node in the IDS.
    yield from _full_tree_iter(ids, ())
71+
72+
73+
def _tree_iter(
    structure: IDSStructure, paths: List[str], curindex: Tuple[int, ...]
) -> Iterator[IndexedNode]:
    """Yield the nodes reached by following ``paths`` through nested AoS.

    Args:
        structure: Structure that the first entry of ``paths`` is relative to.
        paths: Path segments as produced by ``_split_on_aos``. The first entry
            is looked up in ``structure`` and iterated over; it must be followed
            by at least one more segment.
        curindex: Indices of the arrays of structures traversed so far.

    Yields:
        (aos_indices, node) for the node at the final path segment, once for
        every combination of elements of the traversed arrays of structures.
    """
    head, *rest = paths
    # When a single segment remains it is the target node inside each element.
    at_leaf_level = len(rest) == 1

    for i, element in enumerate(structure[head]):
        index = curindex + (i,)
        if at_leaf_level:
            yield index, element[rest[0]]
        else:
            yield from _tree_iter(element, rest, index)
87+
88+
89+
def _full_tree_iter(
    node: IDSStructure, cur_index: Tuple[int, ...]
) -> Iterator[IndexedNode]:
    """Recursively yield every non-empty descendant of ``node``.

    Args:
        node: Structure whose non-empty children (``iter_nonempty_``) are
            visited.
        cur_index: Indices of the ancestor arrays of structures of ``node``.

    Yields:
        (aos_indices, child) for each visited child. A structure or array of
        structures is yielded before its descendants.
    """
    for child in node.iter_nonempty_():
        yield cur_index, child

        if isinstance(child, IDSStructArray):
            # Each element adds its own position to the index tuple.
            for position, element in enumerate(child):
                yield from _full_tree_iter(element, cur_index + (position,))
        elif isinstance(child, IDSStructure):
            # Plain structures do not contribute an index.
            yield from _full_tree_iter(child, cur_index)

imas/backends/netcdf/nc2ids.py

Lines changed: 7 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
import logging
22
import os
3-
from typing import Iterator, List, Optional, Tuple
3+
from typing import Optional
44

55
import netCDF4
66
import numpy as np
77

88
from imas.backends.netcdf import ids2nc
99
from imas.backends.netcdf.nc_metadata import NCMetadata
10+
from imas.backends.netcdf.iterators import indexed_tree_iter
1011
from imas.exception import InvalidNetCDFEntry
11-
from imas.ids_base import IDSBase
1212
from imas.ids_convert import NBCPathMap
1313
from imas.ids_data_type import IDSDataType
1414
from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS
1515
from imas.ids_metadata import IDSMetadata
16-
from imas.ids_structure import IDSStructure
1716
from imas.ids_toplevel import IDSToplevel
1817

1918
logger = logging.getLogger(__name__)
@@ -26,49 +25,6 @@ def variable_error(var, issue, value, expected=None) -> InvalidNetCDFEntry:
2625
)
2726

2827

29-
def split_on_aos(metadata: IDSMetadata):
30-
paths = []
31-
curpath = metadata.name
32-
33-
item = metadata
34-
while item._parent.data_type is not None:
35-
item = item._parent
36-
if item.data_type is IDSDataType.STRUCT_ARRAY:
37-
paths.append(curpath)
38-
curpath = item.name
39-
else:
40-
curpath = f"{item.name}/{curpath}"
41-
paths.append(curpath)
42-
return paths[::-1]
43-
44-
45-
IndexedNode = Tuple[Tuple[int, ...], IDSBase]
46-
47-
48-
def tree_iter(structure: IDSStructure, metadata: IDSMetadata) -> Iterator[IndexedNode]:
49-
paths = split_on_aos(metadata)
50-
if len(paths) == 1:
51-
yield (), structure[paths[0]]
52-
else:
53-
yield from _tree_iter(structure, paths, ())
54-
55-
56-
def _tree_iter(
57-
structure: IDSStructure, paths: List[str], curindex: Tuple[int, ...]
58-
) -> Iterator[IndexedNode]:
59-
aos_path, *paths = paths
60-
aos = structure[aos_path]
61-
62-
if len(paths) == 1:
63-
path = paths[0]
64-
for i, node in enumerate(aos):
65-
yield curindex + (i,), node[path]
66-
67-
else:
68-
for i, node in enumerate(aos):
69-
yield from _tree_iter(node, paths, curindex + (i,))
70-
71-
7228
class NC2IDS:
7329
"""Class responsible for reading an IDS from a NetCDF group."""
7430

@@ -169,7 +125,7 @@ def run(self, lazy: bool) -> None:
169125
if metadata.data_type is IDSDataType.STRUCT_ARRAY:
170126
if "sparse" in var.ncattrs():
171127
shapes = self.group[var_name + ":shape"][()]
172-
for index, node in tree_iter(self.ids, target_metadata):
128+
for index, node in indexed_tree_iter(self.ids, target_metadata):
173129
node.resize(shapes[index][0])
174130

175131
else:
@@ -178,7 +134,7 @@ def run(self, lazy: bool) -> None:
178134
metadata.path_string, self.homogeneous_time
179135
)[-1]
180136
size = self.group.dimensions[dim].size
181-
for _, node in tree_iter(self.ids, target_metadata):
137+
for _, node in indexed_tree_iter(self.ids, target_metadata):
182138
node.resize(size)
183139

184140
continue
@@ -190,15 +146,15 @@ def run(self, lazy: bool) -> None:
190146
if "sparse" in var.ncattrs():
191147
if metadata.ndim:
192148
shapes = self.group[var_name + ":shape"][()]
193-
for index, node in tree_iter(self.ids, target_metadata):
149+
for index, node in indexed_tree_iter(self.ids, target_metadata):
194150
shape = shapes[index]
195151
if shape.all():
196152
# NOTE: bypassing IDSPrimitive.value.setter logic
197153
node._IDSPrimitive__value = data[
198154
index + tuple(map(slice, shape))
199155
]
200156
else:
201-
for index, node in tree_iter(self.ids, target_metadata):
157+
for index, node in indexed_tree_iter(self.ids, target_metadata):
202158
value = data[index]
203159
if value != getattr(var, "_FillValue", None):
204160
# NOTE: bypassing IDSPrimitive.value.setter logic
@@ -211,7 +167,7 @@ def run(self, lazy: bool) -> None:
211167
self.ids[target_metadata.path].value = data
212168

213169
else:
214-
for index, node in tree_iter(self.ids, target_metadata):
170+
for index, node in indexed_tree_iter(self.ids, target_metadata):
215171
# NOTE: bypassing IDSPrimitive.value.setter logic
216172
node._IDSPrimitive__value = data[index]
217173

0 commit comments

Comments
 (0)