30
30
31
31
from eddymotion .data .dmri import DWI
32
32
from eddymotion .estimator import EddyMotionEstimator
33
+ from eddymotion .registration .utils import displacements_within_mask
33
34
34
35
35
- def test_proximity_estimator_trivial_model (datadir ):
36
+ def test_proximity_estimator_trivial_model (datadir , tmp_path ):
36
37
"""Check the proximity of transforms estimated by the estimator with a trivial B0 model."""
37
38
38
39
dwdata = DWI .from_filename (datadir / "dwi.h5" )
39
40
b0nii = nb .Nifti1Image (dwdata .bzero , dwdata .affine , None )
41
+ masknii = nb .Nifti1Image (dwdata .brainmask .astype (np .uint8 ), dwdata .affine , None )
40
42
41
43
# Generate a list of large-yet-plausible bulk-head motion.
42
44
xfms = nt .linear .LinearTransformsMapping (
@@ -56,8 +58,8 @@ def test_proximity_estimator_trivial_model(datadir):
56
58
moved_nii = (~ xfms ).apply (b0nii , reference = b0nii )
57
59
58
60
# Uncomment to see the moved dataset
59
- # moved_nii.to_filename(tmp_path / "test.nii.gz")
60
- # xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz")
61
+ moved_nii .to_filename (tmp_path / "test.nii.gz" )
62
+ xfms .apply (moved_nii ).to_filename (tmp_path / "ground_truth.nii.gz" )
61
63
62
64
# Wrap into dataset object
63
65
dwi_motion = DWI (
@@ -70,7 +72,7 @@ def test_proximity_estimator_trivial_model(datadir):
70
72
71
73
estimator = EddyMotionEstimator ()
72
74
em_affines = estimator .estimate (
73
- dwdata = dwi_motion ,
75
+ data = dwi_motion ,
74
76
models = ("b0" ,),
75
77
seed = None ,
76
78
align_kwargs = {
@@ -81,14 +83,16 @@ def test_proximity_estimator_trivial_model(datadir):
81
83
)
82
84
83
85
# Uncomment to see the realigned dataset
84
- # nt.linear.LinearTransformsMapping(
85
- # em_affines,
86
- # reference=b0nii,
87
- # ).apply(moved_nii).to_filename(tmp_path / "realigned.nii.gz")
86
+ nt .linear .LinearTransformsMapping (
87
+ em_affines ,
88
+ reference = b0nii ,
89
+ ).apply (moved_nii ).to_filename (tmp_path / "realigned.nii.gz" )
88
90
89
91
# For each moved b0 volume
90
92
coords = xfms .reference .ndcoords .T
91
93
for i , est in enumerate (em_affines ):
92
- xfm = nt .linear .Affine (xfms .matrix [i ], reference = b0nii )
93
- est = nt .linear .Affine (est , reference = b0nii )
94
- assert np .sqrt (((xfm .map (coords ) - est .map (coords )) ** 2 ).sum (1 )).mean () < 0.2
94
+ assert displacements_within_mask (
95
+ masknii ,
96
+ nt .linear .Affine (est ),
97
+ xfms [i ],
98
+ ).max () < 0.2
0 commit comments