Skip to content

Commit e93a3ac

Browse files
authored
Merge pull request #467 from nipreps/enh/pooch-cache
RF: Convert pooch retrieval to interface, allow setting cache dir
2 parents ec18eaf + 34195b9 commit e93a3ac

File tree

3 files changed

+100
-43
lines changed

3 files changed

+100
-43
lines changed

nibabies/interfaces/download.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import json
2+
import os
3+
from pathlib import Path
4+
5+
import pooch
6+
from nipype.interfaces.base import (
7+
DynamicTraitedSpec,
8+
File,
9+
SimpleInterface,
10+
TraitedSpec,
11+
traits,
12+
)
13+
14+
import nibabies
15+
16+
17+
class _RetrievePoochFilesInputSpec(DynamicTraitedSpec):
18+
intermediate = traits.Str(required=True, desc='the intermediate space')
19+
target = traits.Str(required=True, desc='the target space')
20+
21+
22+
class _RetrievePoochFilesOutputSpec(TraitedSpec):
23+
int2tgt_xfm = File(desc='Intermediate to target transform')
24+
tgt2int_xfm = File(desc='Target to intermediate transform')
25+
26+
27+
class RetrievePoochFiles(SimpleInterface):
28+
input_spec = _RetrievePoochFilesInputSpec
29+
output_spec = _RetrievePoochFilesOutputSpec
30+
31+
def _run_interface(self, runtime):
32+
int2tgt, tgt2int = _retrieve_xfms(self.inputs.intermediate, self.inputs.target)
33+
self._results['int2tgt_xfm'] = int2tgt
34+
self._results['tgt2int_xfm'] = tgt2int
35+
return runtime
36+
37+
38+
def _retrieve_xfms(
39+
intermediate: str,
40+
target: str,
41+
):
42+
"""Fetch transforms from the OSF repository (https://osf.io/y763j/)."""
43+
44+
manifest = json.loads(nibabies.data.load('xfm_manifest.json').read_text())
45+
46+
def sanitize(space):
47+
# MNIInfant:cohort-1 -> MNIInfant+1
48+
return space.replace(':cohort-', '+')
49+
50+
intmd = sanitize(intermediate)
51+
tgt = sanitize(target)
52+
53+
cache_dir = Path(os.getenv('NIBABIES_POOCH_DIR', Path.cwd()))
54+
55+
int2std_name = f'from-{intmd}_to-{tgt}_xfm.h5'
56+
int2std_meta = manifest[int2std_name]
57+
int2std = pooch.retrieve(
58+
url=int2std_meta['url'],
59+
path=cache_dir,
60+
known_hash=int2std_meta['hash'],
61+
fname=int2std_name,
62+
)
63+
64+
std2int_name = f'from-{tgt}_to-{intmd}_xfm.h5'
65+
std2int_meta = manifest[std2int_name]
66+
std2int = pooch.retrieve(
67+
url=std2int_meta['url'],
68+
path=cache_dir,
69+
known_hash=std2int_meta['hash'],
70+
fname=std2int_name,
71+
)
72+
73+
return int2std, std2int
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from pathlib import Path
2+
3+
from nibabies.interfaces.download import RetrievePoochFiles
4+
5+
6+
def test_RetrievePoochFiles(tmp_path, monkeypatch):
7+
monkeypatch.chdir(tmp_path)
8+
getter = RetrievePoochFiles(intermediate='MNIInfant:cohort-1', target='MNI152NLin6Asym')
9+
outputs = getter.run().outputs
10+
assert Path(outputs.int2tgt_xfm).exists()
11+
assert Path(outputs.tgt2int_xfm).exists()
12+
13+
cache = tmp_path / 'mycache'
14+
monkeypatch.setenv('NIBABIES_POOCH_DIR', cache)
15+
getter = RetrievePoochFiles(intermediate='MNIInfant:cohort-1', target='MNI152NLin6Asym')
16+
outputs = getter.run().outputs
17+
18+
assert Path(outputs.int2tgt_xfm) == cache / 'from-MNIInfant+1_to-MNI152NLin6Asym_xfm.h5'
19+
assert Path(outputs.tgt2int_xfm) == cache / 'from-MNI152NLin6Asym_to-MNIInfant+1_xfm.h5'

nibabies/workflows/anatomical/registration.py

Lines changed: 8 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def init_concat_registrations_wf(
345345
further use in downstream nodes.
346346
347347
"""
348+
from nibabies.interfaces.download import RetrievePoochFiles
348349
from nibabies.interfaces.patches import CompositeTransformUtil
349350

350351
ntpls = len(templates)
@@ -405,11 +406,9 @@ def init_concat_registrations_wf(
405406
outputnode = pe.Node(niu.IdentityInterface(fields=out_fields), name='outputnode')
406407

407408
intermed_xfms = pe.MapNode(
408-
niu.Function(
409-
function=_load_intermediate_xfms, output_names=['int2std_xfm', 'std2int_xfm']
410-
),
411-
name='intermed_xfms',
412-
iterfield=['std'],
409+
RetrievePoochFiles(),
410+
name='retrieve_xfms',
411+
iterfield=['target'],
413412
run_without_submitting=True,
414413
)
415414

@@ -464,10 +463,10 @@ def init_concat_registrations_wf(
464463
# Transform concatenation
465464
(inputnode, dis_anat2int, [('anat2int_xfm', 'in_file')]),
466465
(inputnode, dis_int2anat, [('int2anat_xfm', 'in_file')]),
467-
(inputnode, intermed_xfms, [('intermediate', 'intermediate')]),
468-
(inputnode, intermed_xfms, [('template', 'std')]),
469-
(intermed_xfms, dis_int2std, [('int2std_xfm', 'in_file')]),
470-
(intermed_xfms, dis_std2int, [('std2int_xfm', 'in_file')]),
466+
(inputnode, intermed_xfms, [('intermediate', 'intermediate'),
467+
('template', 'target')]),
468+
(intermed_xfms, dis_int2std, [('int2tgt_xfm', 'in_file')]),
469+
(intermed_xfms, dis_std2int, [('tgt2int_xfm', 'in_file')]),
471470
(dis_anat2int, order_anat2std, [
472471
('affine_transform', 'in1'),
473472
('displacement_field', 'in2'),
@@ -505,40 +504,6 @@ def init_concat_registrations_wf(
505504
return workflow
506505

507506

508-
def _load_intermediate_xfms(intermediate, std):
509-
"""Fetch transforms from the OSF repository (https://osf.io/y763j/)."""
510-
import json
511-
from pathlib import Path
512-
513-
import pooch
514-
515-
from nibabies.data import load
516-
517-
xfms = json.loads(load('xfm_manifest.json').read_text())
518-
# MNIInfant:cohort-1 -> MNIInfant+1
519-
intmed = intermediate.replace(':cohort-', '+')
520-
521-
int2std_name = f'from-{intmed}_to-{std}_xfm.h5'
522-
int2std_meta = xfms[int2std_name]
523-
int2std = pooch.retrieve(
524-
url=int2std_meta['url'],
525-
path=Path.cwd(),
526-
known_hash=int2std_meta['hash'],
527-
fname=int2std_name,
528-
)
529-
530-
std2int_name = f'from-{std}_to-{intmed}_xfm.h5'
531-
std2int_meta = xfms[std2int_name]
532-
std2int = pooch.retrieve(
533-
url=std2int_meta['url'],
534-
path=Path.cwd(),
535-
known_hash=std2int_meta['hash'],
536-
fname=std2int_name,
537-
)
538-
539-
return int2std, std2int
540-
541-
542507
def _create_inverse_composite(in_file, out_file='inverse_composite.h5'):
543508
"""
544509
Build a composite transform with SimpleITK.

0 commit comments

Comments
 (0)