Skip to content

Commit dab9eb9

Browse files
committed
ENH: Add workflow to verify derivatives are compatible with anatomical reference
1 parent 1e6afd4 commit dab9eb9

File tree

3 files changed

+95
-15
lines changed

3 files changed

+95
-15
lines changed

nibabies/workflows/anatomical/fit.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
from nibabies.interfaces import DerivativesDataSink
3939
from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf
4040
from nibabies.workflows.anatomical.outputs import init_anat_reports_wf, init_coreg_report_wf
41-
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf
41+
from nibabies.workflows.anatomical.preproc import (
42+
init_anat_preproc_wf,
43+
init_conform_derivative_wf,
44+
init_csf_norm_wf,
45+
)
4246
from nibabies.workflows.anatomical.registration import (
4347
init_concat_registrations_wf,
4448
init_coregistration_wf,
@@ -172,11 +176,11 @@ def init_infant_anat_fit_wf(
172176

173177
# Stage 2 - Anatomicals
174178
t1w_buffer = pe.Node(
175-
niu.IdentityInterface(fields=['t1w_preproc', 't1w_maskt1w_brain']),
179+
niu.IdentityInterface(fields=['t1w_preproc', 't1w_mask', 't1w_brain']),
176180
name='t1w_buffer',
177181
)
178182
t2w_buffer = pe.Node(
179-
niu.IdentityInterface(fields=['t2w_preproc', 't2w_maskt2w_brain', 't2w_probmap']),
183+
niu.IdentityInterface(fields=['t2w_preproc', 't2w_mask', 't2w_brain', 't2w_probmap']),
180184
name='t2w_buffer',
181185
)
182186
anat_buffer = pe.Node(
@@ -323,6 +327,7 @@ def init_infant_anat_fit_wf(
323327

324328
t1w_preproc = precomputed.get('t1w_preproc')
325329
t2w_preproc = precomputed.get('t2w_preproc')
330+
anat_preproc = precomputed.get(f'{anat}_preproc')
326331

327332
# Stage 1: Conform & valid T1w/T2w images
328333
# Note: Since stage 1 & 2 are tightly knit together, it may be more intuitive
@@ -575,21 +580,28 @@ def init_infant_anat_fit_wf(
575580
'A pre-computed T1w brain mask was provided as input and used throughout the '
576581
'workflow.'
577582
)
578-
t1w_buffer.inputs.t1w_mask = t1w_mask
579583
apply_t1w_mask.inputs.in_mask = t1w_mask
580584
workflow.connect(apply_t1w_mask, 'out_file', t1w_buffer, 't1w_brain')
581585

582586
if not t1w_preproc:
587+
# Ensure compatibility with T1w template
588+
conform_t1w_mask_wf = init_conform_derivative_wf(
589+
in_file=t1w_mask, name='conform_t1w_mask_wf'
590+
)
591+
583592
LOGGER.info('ANAT Skipping skull-strip, INU-correction only')
584593
t1w_n4_wf = init_anat_preproc_wf(name='t1w_n4_wf')
585594
workflow.connect([
595+
(t1w_validate, conform_t1w_mask_wf, [('out_file', 'inputnode.ref_file')]),
596+
(conform_t1w_mask_wf, t1w_buffer, [('outputnode.out_file', 't1w_mask')]),
586597
(t1w_validate, t1w_n4_wf, [('out_file', 'inputnode.in_anat')]),
587598
(t1w_n4_wf, t1w_buffer, [('outputnode.anat_preproc', 't1w_preproc')]),
588599
(t1w_n4_wf, apply_t1w_mask, [('outputnode.anat_preproc', 'in_file')]),
589600
]) # fmt:skip
590601
else:
591602
LOGGER.info('ANAT Skipping T1w masking')
592603
workflow.connect(t1w_validate, 'out_file', apply_t1w_mask, 'in_file')
604+
t1w_buffer.inputs.t1w_mask = t1w_mask
593605

594606
# T2w masking logic:
595607
#
@@ -701,21 +713,29 @@ def init_infant_anat_fit_wf(
701713
'A pre-computed T2w brain mask was provided as input and used throughout the '
702714
'workflow.'
703715
)
704-
t2w_buffer.inputs.t2w_mask = t2w_mask
705716
apply_t2w_mask.inputs.in_mask = t2w_mask
706717
workflow.connect(apply_t2w_mask, 'out_file', t2w_buffer, 't2w_brain')
707718

708719
if not t2w_preproc:
720+
# Ensure compatibility with T2w template
721+
conform_t2w_mask_wf = init_conform_derivative_wf(
722+
in_file=t2w_mask,
723+
name='conform_t2w_mask_wf',
724+
)
725+
709726
LOGGER.info('ANAT Skipping skull-strip, INU-correction only')
710727
t2w_n4_wf = init_anat_preproc_wf(name='t2w_n4_wf')
711728
workflow.connect([
729+
(t2w_validate, conform_t2w_mask_wf, [('out_file', 'inputnode.ref_file')]),
730+
(conform_t2w_mask_wf, t2w_buffer, [('outputnode.out_file', 't2w_mask')]),
712731
(t2w_validate, t2w_n4_wf, [('out_file', 'inputnode.in_anat')]),
713732
(t2w_n4_wf, t2w_buffer, [('outputnode.anat_preproc', 't2w_preproc')]),
714733
(t2w_n4_wf, apply_t2w_mask, [('outputnode.anat_preproc', 'in_file')]),
715734
]) # fmt:skip
716735
else:
717736
LOGGER.info('ANAT Skipping T2w masking')
718737
workflow.connect(t2w_validate, 'out_file', apply_t2w_mask, 'in_file')
738+
t2w_buffer.inputs.t2w_mask = t2w_mask
719739

720740
# Stage 3: Coregistration
721741
t1w2t2w_xfm = precomputed.get('t1w2t2w_xfm')
@@ -819,7 +839,19 @@ def init_infant_anat_fit_wf(
819839

820840
if anat_aseg:
821841
LOGGER.info('ANAT Found precomputed anatomical segmentation')
822-
aseg_buffer.inputs.anat_aseg = anat_aseg
842+
# Ensure compatibility with anatomical template
843+
if not anat_preproc:
844+
conform_aseg_wf = init_conform_derivative_wf(
845+
in_file=anat_aseg,
846+
name='conform_aseg_wf',
847+
)
848+
849+
workflow.connect([
850+
(anat_buffer, conform_aseg_wf, [('anat_preproc', 'inputnode.ref_file')]),
851+
(conform_aseg_wf, aseg_buffer, [('outputnode.out_file', 'anat_aseg')]),
852+
]) # fmt:skip
853+
else:
854+
aseg_buffer.inputs.anat_aseg = anat_aseg
823855

824856
if not (anat_dseg and anat_tpms):
825857
LOGGER.info('ANAT Stage 4: Tissue segmentation')
@@ -1714,27 +1746,46 @@ def init_infant_single_anat_fit_wf(
17141746
else:
17151747
LOGGER.info(f'ANAT Found {reference_anat} brain mask')
17161748
desc += 'A pre-computed brain mask was provided as input and used throughout the workflow.'
1717-
anat_buffer.inputs.anat_mask = anat_mask
17181749
apply_mask.inputs.in_mask = anat_mask
17191750
workflow.connect(apply_mask, 'out_file', anat_buffer, 'anat_brain')
17201751

17211752
if not anat_preproc:
1753+
conform_anat_mask_wf = init_conform_derivative_wf(
1754+
in_file=anat_mask,
1755+
name='conform_anat_mask_wf',
1756+
)
1757+
17221758
LOGGER.info('ANAT Skipping skull-strip, INU-correction only')
17231759
anat_n4_wf = init_anat_preproc_wf(name='anat_n4_wf')
17241760
workflow.connect([
1761+
(anat_validate, conform_anat_mask_wf, [('out_file', 'inputnode.ref_file')]),
1762+
(conform_anat_mask_wf, anat_buffer, [('outputnode.out_file', 'anat_mask')]),
17251763
(anat_validate, anat_n4_wf, [('out_file', 'inputnode.in_anat')]),
17261764
(anat_n4_wf, anat_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
17271765
(anat_n4_wf, apply_mask, [('outputnode.anat_preproc', 'in_file')]),
17281766
]) # fmt:skip
17291767
else:
17301768
LOGGER.info(f'ANAT Skipping {reference_anat} masking')
17311769
workflow.connect(anat_validate, 'out_file', apply_mask, 'in_file')
1770+
anat_buffer.inputs.anat_mask = anat_mask
17321771

17331772
# Stage 3: Segmentation
17341773
seg_method = 'jlf' if config.execution.segmentation_atlases_dir else 'fast'
17351774
if anat_aseg:
17361775
LOGGER.info('ANAT Found precomputed anatomical segmentation')
1737-
aseg_buffer.inputs.anat_aseg = anat_aseg
1776+
# Ensure compatibility with anatomical template
1777+
if not anat_preproc:
1778+
conform_aseg_wf = init_conform_derivative_wf(
1779+
in_file=anat_aseg,
1780+
name='conform_aseg_wf',
1781+
)
1782+
1783+
workflow.connect([
1784+
(anat_buffer, conform_aseg_wf, [('anat_preproc', 'inputnode.ref_file')]),
1785+
(conform_aseg_wf, aseg_buffer, [('outputnode.out_file', 'anat_aseg')]),
1786+
]) # fmt:skip
1787+
else:
1788+
aseg_buffer.inputs.anat_aseg = anat_aseg
17381789

17391790
if not (anat_dseg and anat_tpms):
17401791
LOGGER.info('ANAT Stage 3: Tissue segmentation')

nibabies/workflows/anatomical/preproc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,41 @@ def init_csf_norm_wf(name: str = 'csf_norm_wf') -> Workflow:
9898
return workflow
9999

100100

101+
def init_conform_derivative_wf(
102+
*, in_file: str = None, name: str = 'conform_derivative_wf'
103+
) -> pe.Workflow:
104+
"""
105+
Ensure derivatives share the same space as anatomical references.
106+
107+
This workflow is used when a derivative is provided without a reference.
108+
"""
109+
from niworkflows.interfaces.header import MatchHeader
110+
from niworkflows.interfaces.images import Conform, TemplateDimensions
111+
112+
workflow = pe.Workflow(name=name)
113+
inputnode = pe.Node(niu.IdentityInterface(fields=['in_file', 'ref_file']), name='inputnode')
114+
inputnode.inputs.in_file = in_file
115+
outputnode = pe.Node(niu.IdentityInterface(fields=['out_file']), name='outputnode')
116+
117+
ref_dims = pe.Node(TemplateDimensions(), name='ref_dims')
118+
conform = pe.Node(Conform(), name='conform')
119+
# Avoid mismatch tolerance from input
120+
match_header = pe.Node(MatchHeader(), name='match_header')
121+
122+
workflow.connect([
123+
(inputnode, ref_dims, [('ref_file', 'anat_list')]),
124+
(ref_dims, conform, [
125+
('target_zooms', 'target_zooms'),
126+
('target_shape', 'target_shape'),
127+
]),
128+
(inputnode, conform, [('in_file', 'in_file')]),
129+
(conform, match_header, [('out_file', 'in_file')]),
130+
(inputnode, match_header, [('ref_file', 'reference')]),
131+
(match_header, outputnode, [('out_file', 'out_file')]),
132+
]) # fmt:skip
133+
return workflow
134+
135+
101136
def _normalize_roi(in_file, mask_file, threshold=0.2, out_file=None):
102137
"""Normalize low intensity voxels that fall within a given mask."""
103138
import nibabel as nb

nibabies/workflows/anatomical/surfaces.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from niworkflows.interfaces.freesurfer import (
1616
PatchedRobustRegister as RobustRegister,
1717
)
18-
from niworkflows.interfaces.header import MatchHeader
1918
from niworkflows.interfaces.morphology import BinaryDilation
2019
from niworkflows.interfaces.patches import FreeSurferSource
2120
from smriprep.interfaces.freesurfer import MakeMidthickness
@@ -128,9 +127,6 @@ def init_mcribs_surface_recon_wf(
128127
mask_dil = pe.Node(BinaryDilation(radius=3), name='mask_dil')
129128
mask_las = pe.Node(ReorientImage(target_orientation='LAS'), name='mask_las')
130129

131-
# N4 has low tolerance for mismatch between input / mask
132-
match_header = pe.Node(MatchHeader(), name='match_header')
133-
134130
# N4BiasCorrection occurs in MCRIBTissueSegMCRIBS (which is skipped)
135131
# Run it (with mask to rescale intensities) prior injection
136132
n4_mcribs = pe.Node(
@@ -182,9 +178,7 @@ def init_mcribs_surface_recon_wf(
182178
('subjects_dir', 'subjects_dir'),
183179
('subject_id', 'subject_id')]),
184180
(t2w_las, n4_mcribs, [('out_file', 'input_image')]),
185-
(mask_las, match_header, [('out_file', 'in_file')]),
186-
(t2w_las, match_header, [('out_file', 'reference')]),
187-
(match_header, n4_mcribs, [('out_file', 'mask_image')]),
181+
(mask_las, n4_mcribs, [('out_file', 'mask_image')]),
188182
(n4_mcribs, mcribs_recon, [('output_image', 't2w_file')]),
189183
(seg_las, mcribs_recon, [('out_file', 'segmentation_file')]),
190184
(inputnode, mcribs_postrecon, [

0 commit comments

Comments
 (0)