Skip to content

Commit 59fab27

Browse files
committed
ENH: Get the nd_sort method mostly working w/ basic tests
1 parent 583e0aa commit 59fab27

File tree

2 files changed

+120
-19
lines changed

2 files changed

+120
-19
lines changed

nibabel/metasum.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ def __getitem__(self, value):
100100
ba = self._val_bitarrs[value]
101101
return list(self._extract_indices(ba))
102102

103+
def first(self, value):
104+
'''Return the first index where this value appears'''
105+
if self._const_val == value:
106+
return 0
107+
idx = self._unique_vals.get(value)
108+
if idx is not None:
109+
return idx
110+
return self._val_bitarrs[value].index(True)
111+
103112
def values(self):
104113
'''Generate each unique value that has been seen'''
105114
if self._const_val is not _NoValue:
@@ -339,8 +348,15 @@ def get_block_size(self):
339348
return None
340349
return block_size
341350

342-
def is_subpartition(self, other):
343-
''''''
351+
def is_orthogonal(self, other, size=1):
352+
'''Check our value's indices overlaps each from `other` exactly `size` times
353+
'''
354+
other_bas = {v: other.get_mask(v) for v in other.values()}
355+
for val in self.values():
356+
for other_val, other_ba in other_bas.items():
357+
if self.count(val, mask=other_ba) != size:
358+
return False
359+
return True
344360

345361
def _extract_indices(self, ba):
346362
'''Generate integer indices from bitarray representation'''
@@ -416,7 +432,7 @@ def _extend_const(self, other):
416432

417433

418434
class DimTypes(IntEnum):
419-
'''Enmerate the three types of nD dimensions'''
435+
'''Enumerate the three types of nD dimensions'''
420436
SLICE = 1
421437
TIME = 2
422438
PARAM = 3
@@ -556,8 +572,9 @@ def nd_sort(self, dims):
556572
last_dim = dim
557573

558574
# Pull out info about different types of dims
559-
n_slices = None
560-
n_vol = None
575+
n_input = self._n_input
576+
total_vol = None
577+
slice_dim = None
561578
time_dim = None
562579
param_dims = []
563580
n_params = []
@@ -568,9 +585,10 @@ def nd_sort(self, dims):
568585
dim_vidx = self._v_idxs[dim.key]
569586
dim_type = dim.dim_type
570587
if dim_type is DimTypes.SLICE:
588+
slice_dim = dim
571589
n_slices = len(dim_vidx)
572-
n_vol = dim_vidx.get_block_size()
573-
if n_vol is None:
590+
total_vol = dim_vidx.get_block_size()
591+
if total_vol is None:
574592
raise NdSortError("There are missing or extra slices")
575593
shape.append(n_slices)
576594
curr_size *= n_slices
@@ -583,29 +601,39 @@ def nd_sort(self, dims):
583601
n_param = len(dim_vidx)
584602
n_params.append(n_param)
585603
total_params *= n_param
586-
if n_vol is None:
587-
n_vol = self._n_input
604+
if total_vol is None:
605+
total_vol = n_input
588606

589-
# Size of the time dimension must be infered from the size of the other dims
607+
# Size of the time dimension must be inferred from the size of the other dims
590608
n_time = 1
609+
prev_dim = slice_dim
591610
if time_dim is not None:
592-
n_time, rem = divmod(n_vol, total_params)
611+
n_time, rem = divmod(total_vol, total_params)
593612
if rem != 0:
594-
raise NdSortError(f"The combined parameters don't evenly divide inputs")
613+
raise NdSortError("The combined parameters don't evenly divide inputs")
595614
shape.append(n_time)
596615
curr_size *= n_time
616+
prev_dim = time_dim
597617

598-
# Complete the "shape", and do a more detailed check that our param dims make sense
618+
# Complete the "shape", and do a more detailed check that our dims make sense
599619
for dim, n_param in zip(param_dims, n_params):
600620
dim_vidx = self._v_idxs[dim.key]
601-
if dim_vidx.get_block_size() != curr_size:
621+
if dim_vidx.get_block_size() != n_input // n_param:
602622
raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
623+
if prev_dim is not None and prev_dim.dim_type != DimTypes.TIME:
624+
count_per = (curr_size // shape[-1]) * (n_input // (curr_size * n_param))
625+
if not self._v_idxs[prev_dim.key].is_orthogonal(dim_vidx, count_per):
626+
raise NdSortError("The dimensions are not orthogonal")
603627
shape.append(n_param)
604628
curr_size *= n_param
629+
prev_dim = dim
605630

606631
# Extract dim keys for each input and do the actual sort
607632
sort_keys = [(idx, tuple(self.get_val(idx, dim.key) for dim in reversed(dims)))
608-
for idx in range(self._n_input)]
633+
for idx in range(n_input)]
609634
sort_keys.sort(key=lambda x: x[1])
610635

611-
# TODO: Finish this
636+
# TODO: If we have non-singular time dimension we need to do some additional
637+
# validation checks here after sorting.
638+
639+
return tuple(shape), [x[0] for x in sort_keys]

nibabel/tests/test_metasum.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from ..metasum import MetaSummary, ValueIndices
1+
import random
22

33
import pytest
4+
import numpy as np
5+
6+
from ..metasum import DimIndex, DimTypes, MetaSummary, ValueIndices
47

58

69
vidx_test_patterns = ([0] * 8,
@@ -14,15 +17,15 @@
1417

1518
@pytest.mark.parametrize("in_list", vidx_test_patterns)
1619
def test_value_indices_basics(in_list):
17-
'''Test we can roundtrip list -> ValueIndices -> list'''
20+
'''Test basic ValueIndices behavior'''
1821
vidx = ValueIndices(in_list)
1922
assert vidx.n_input == len(in_list)
2023
assert len(vidx) == len(set(in_list))
2124
assert sorted(vidx.values()) == sorted(list(set(in_list)))
2225
for val in vidx.values():
2326
assert vidx.count(val) == in_list.count(val)
2427
for in_idx in vidx[val]:
25-
assert in_list[in_idx] == val
28+
assert in_list[in_idx] == val == vidx.get_value(in_idx)
2629
out_list = vidx.to_list()
2730
assert in_list == out_list
2831

@@ -78,3 +81,73 @@ def test_meta_summary_basics(in_dicts):
7881
assert out_dict == in_dicts[in_idx]
7982
for key, in_val in in_dicts[in_idx].items():
8083
assert in_val == msum.get_val(in_idx, key)
84+
85+
86+
def _make_nd_meta(shape, dim_info, const_meta=None):
87+
if const_meta is None:
88+
const_meta = {'series_number': '5'}
89+
meta_seq = []
90+
for nd_idx in np.ndindex(*shape):
91+
curr_meta = {}
92+
curr_meta.update(const_meta)
93+
for dim, dim_idx in zip(dim_info, nd_idx):
94+
curr_meta[dim.key] = dim_idx
95+
meta_seq.append(curr_meta)
96+
return meta_seq
97+
98+
99+
ndsort_test_args = (((3,),
100+
(DimIndex(DimTypes.SLICE, 'slice_location'),),
101+
None),
102+
((3, 5),
103+
(DimIndex(DimTypes.SLICE, 'slice_location'),
104+
DimIndex(DimTypes.TIME, 'acq_time')),
105+
None),
106+
((3, 5),
107+
(DimIndex(DimTypes.SLICE, 'slice_location'),
108+
DimIndex(DimTypes.PARAM, 'inversion_time')),
109+
None),
110+
((3, 5, 7),
111+
(DimIndex(DimTypes.SLICE, 'slice_location'),
112+
DimIndex(DimTypes.TIME, 'acq_time'),
113+
DimIndex(DimTypes.PARAM, 'echo_time')),
114+
None),
115+
((3, 5, 7),
116+
(DimIndex(DimTypes.SLICE, 'slice_location'),
117+
DimIndex(DimTypes.PARAM, 'inversion_time'),
118+
DimIndex(DimTypes.PARAM, 'echo_time')),
119+
None),
120+
((5, 3),
121+
(DimIndex(DimTypes.TIME, 'acq_time'),
122+
DimIndex(DimTypes.PARAM, 'echo_time')),
123+
None),
124+
((3, 5, 7),
125+
(DimIndex(DimTypes.TIME, 'acq_time'),
126+
DimIndex(DimTypes.PARAM, 'inversion_time'),
127+
DimIndex(DimTypes.PARAM, 'echo_time')),
128+
None),
129+
((5, 7),
130+
(DimIndex(DimTypes.PARAM, 'inversion_time'),
131+
DimIndex(DimTypes.PARAM, 'echo_time')),
132+
None),
133+
((5, 7, 3),
134+
(DimIndex(DimTypes.PARAM, 'inversion_time'),
135+
DimIndex(DimTypes.PARAM, 'echo_time'),
136+
DimIndex(DimTypes.PARAM, 'repetition_time')),
137+
None),
138+
)
139+
140+
141+
@pytest.mark.parametrize("shape,dim_info,const_meta", ndsort_test_args)
142+
def test_ndsort(shape, dim_info, const_meta):
143+
meta_seq = _make_nd_meta(shape, dim_info, const_meta)
144+
rand_idx_seq = [(i, m) for i, m in enumerate(meta_seq)]
145+
# TODO: Use some pytest plugin to manage randomness? Just use fixed seed?
146+
random.shuffle(rand_idx_seq)
147+
rand_idx = [x[0] for x in rand_idx_seq]
148+
rand_seq = [x[1] for x in rand_idx_seq]
149+
msum = MetaSummary()
150+
for meta in rand_seq:
151+
msum.append(meta)
152+
out_shape, out_idxs = msum.nd_sort(dim_info)
153+
assert shape == out_shape

0 commit comments

Comments
 (0)