Skip to content

Commit 7839dc3

Browse files
authored
Merge pull request #163 from oesteban/fix/template-ids-cohort
ENH: Fix template keys output in normalization workflow, when cohort present
2 parents 557b689 + 0e5595e commit 7839dc3

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed

smriprep/workflows/norm.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,29 @@ def init_anat_norm_wf(
133133
workflow.__desc__ += (', ', '.')[template == templates[-1][0]]
134134

135135
inputnode = pe.Node(niu.IdentityInterface(fields=[
136-
'moving_image', 'moving_mask', 'moving_segmentation', 'moving_tpms',
137-
'lesion_mask', 'orig_t1w']),
138-
name='inputnode')
139-
out_fields = ['standardized', 'anat2std_xfm', 'std2anat_xfm',
140-
'std_mask', 'std_dseg', 'std_tpms', 'template', 'template_spec']
136+
'lesion_mask',
137+
'moving_image',
138+
'moving_mask',
139+
'moving_segmentation',
140+
'moving_tpms',
141+
'orig_t1w',
142+
'template',
143+
]), name='inputnode')
144+
inputnode.iterables = [('template', templates)]
145+
146+
out_fields = [
147+
'anat2std_xfm',
148+
'standardized',
149+
'std2anat_xfm',
150+
'std_dseg',
151+
'std_mask',
152+
'std_tpms',
153+
'template',
154+
'template_spec',
155+
]
141156
poutputnode = pe.Node(niu.IdentityInterface(fields=out_fields), name='poutputnode')
142157

143158
split_desc = pe.Node(TemplateDesc(), run_without_submitting=True, name='split_desc')
144-
split_desc.iterables = [('template', templates)]
145159

146160
tf_select = pe.Node(TemplateFlowSelect(resolution=1 + debug),
147161
name='tf_select', run_without_submitting=True)
@@ -169,6 +183,8 @@ def init_anat_norm_wf(
169183
iterfield=['input_image'], name='std_tpms')
170184

171185
workflow.connect([
186+
(inputnode, split_desc, [('template', 'template')]),
187+
(inputnode, poutputnode, [('template', 'template')]),
172188
(inputnode, trunc_mov, [('moving_image', 'op1')]),
173189
(inputnode, registration, [
174190
('moving_mask', 'moving_mask'),
@@ -198,13 +214,12 @@ def init_anat_norm_wf(
198214
(std_mask, poutputnode, [('output_image', 'std_mask')]),
199215
(std_dseg, poutputnode, [('output_image', 'std_dseg')]),
200216
(std_tpms, poutputnode, [('output_image', 'std_tpms')]),
201-
(split_desc, poutputnode, [('name', 'template'),
202-
('spec', 'template_spec')]),
217+
(split_desc, poutputnode, [('spec', 'template_spec')]),
203218
])
204219

205220
# Provide synchronized output
206221
outputnode = pe.JoinNode(niu.IdentityInterface(fields=out_fields),
207-
name='outputnode', joinsource='split_desc')
222+
name='outputnode', joinsource='inputnode')
208223
workflow.connect([
209224
(poutputnode, outputnode, [(f, f) for f in out_fields]),
210225
])

smriprep/workflows/outputs.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def init_anat_reports_wf(reportlets_dir, freesurfer,
6464
name='ds_std_t1w_report', run_without_submitting=True)
6565

6666
workflow.connect([
67-
(inputnode, tf_select, [('template', 'template'),
67+
(inputnode, tf_select, [(('template', _drop_cohort), 'template'),
6868
('template_spec', 'template_spec')]),
6969
(inputnode, norm_rpt, [('template', 'before_label')]),
7070
(inputnode, norm_msk, [('std_t1w', 'after'),
@@ -74,7 +74,7 @@ def init_anat_reports_wf(reportlets_dir, freesurfer,
7474
(norm_msk, norm_rpt, [('before', 'before'),
7575
('after', 'after')]),
7676
(inputnode, ds_std_t1w_report, [
77-
('template', 'space'),
77+
(('template', _fmt_cohort), 'space'),
7878
('source_file', 'source_file')]),
7979
(norm_rpt, ds_std_t1w_report, [('out_report', 'in_file')]),
8080
])
@@ -193,23 +193,23 @@ def init_anat_derivatives_wf(bids_root, freesurfer, num_t1w, output_dir,
193193
# Template
194194
(inputnode, ds_t1w_tpl_warp, [
195195
('anat2std_xfm', 'in_file'),
196-
('template', 'to')]),
196+
(('template', _drop_cohort), 'to')]),
197197
(inputnode, ds_t1w_tpl_inv_warp, [
198198
('std2anat_xfm', 'in_file'),
199-
('template', 'from')]),
199+
(('template', _drop_cohort), 'from')]),
200200
(inputnode, ds_t1w_tpl, [
201201
('std_t1w', 'in_file'),
202-
('template', 'space')]),
202+
(('template', _fmt_cohort), 'space')]),
203203
(inputnode, ds_std_mask, [
204204
('std_mask', 'in_file'),
205-
('template', 'space'),
205+
(('template', _fmt_cohort), 'space'),
206206
(('template', _rawsources), 'RawSources')]),
207-
(inputnode, ds_std_dseg, [('template', 'space')]),
207+
(inputnode, ds_std_dseg, [(('template', _fmt_cohort), 'space')]),
208208
(inputnode, lut_std_dseg, [('std_dseg', 'in_file')]),
209209
(lut_std_dseg, ds_std_dseg, [('out', 'in_file')]),
210210
(inputnode, ds_std_tpms, [
211211
('std_tpms', 'in_file'),
212-
('template', 'space')]),
212+
(('template', _fmt_cohort), 'space')]),
213213
(t1w_name, ds_t1w_tpl_warp, [('out', 'source_file')]),
214214
(t1w_name, ds_t1w_tpl_inv_warp, [('out', 'source_file')]),
215215
(t1w_name, ds_t1w_tpl, [('out', 'source_file')]),
@@ -327,3 +327,11 @@ def _rpt_masks(mask_file, before, after, after_mask=None):
327327
nb.Nifti1Image(anii.get_fdata() * msk,
328328
anii.affine, anii.header).to_filename('after.nii.gz')
329329
return abspath('before.nii.gz'), abspath('after.nii.gz')
330+
331+
332+
def _drop_cohort(in_template):
333+
return in_template.split(':')[0]
334+
335+
336+
def _fmt_cohort(in_template):
337+
return in_template.replace(':', '_')

0 commit comments

Comments
 (0)