Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions nibabies/interfaces/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
import os
from pathlib import Path

import pooch
from nipype.interfaces.base import (
DynamicTraitedSpec,
File,
SimpleInterface,
TraitedSpec,
traits,
)

import nibabies


class _RetrievePoochFilesInputSpec(DynamicTraitedSpec):
intermediate = traits.Str(required=True, desc='the intermediate space')
target = traits.Str(required=True, desc='the target space')


class _RetrievePoochFilesOutputSpec(TraitedSpec):
int2tgt_xfm = File(desc='Intermediate to target transform')
tgt2int_xfm = File(desc='Target to intermediate transform')


class RetrievePoochFiles(SimpleInterface):
input_spec = _RetrievePoochFilesInputSpec
output_spec = _RetrievePoochFilesOutputSpec

def _run_interface(self, runtime):
int2tgt, tgt2int = _retrieve_xfms(self.inputs.intermediate, self.inputs.target)
self._results['int2tgt_xfm'] = int2tgt
self._results['tgt2int_xfm'] = tgt2int
return runtime


def _retrieve_xfms(
intermediate: str,
target: str,
):
"""Fetch transforms from the OSF repository (https://osf.io/y763j/)."""

manifest = json.loads(nibabies.data.load('xfm_manifest.json').read_text())

def sanitize(space):
# MNIInfant:cohort-1 -> MNIInfant+1
return space.replace(':cohort-', '+')

intmd = sanitize(intermediate)
tgt = sanitize(target)

cache_dir = Path(os.getenv('NIBABIES_POOCH_DIR', Path.cwd()))

int2std_name = f'from-{intmd}_to-{tgt}_xfm.h5'
int2std_meta = manifest[int2std_name]
int2std = pooch.retrieve(
url=int2std_meta['url'],
path=cache_dir,
known_hash=int2std_meta['hash'],
fname=int2std_name,
)

std2int_name = f'from-{tgt}_to-{intmd}_xfm.h5'
std2int_meta = manifest[std2int_name]
std2int = pooch.retrieve(
url=std2int_meta['url'],
path=cache_dir,
known_hash=std2int_meta['hash'],
fname=std2int_name,
)

return int2std, std2int
19 changes: 19 additions & 0 deletions nibabies/interfaces/tests/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path

from nibabies.interfaces.download import RetrievePoochFiles


def test_RetrievePoochFiles(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
getter = RetrievePoochFiles(intermediate='MNIInfant:cohort-1', target='MNI152NLin6Asym')
outputs = getter.run().outputs
assert Path(outputs.int2tgt_xfm).exists()
assert Path(outputs.tgt2int_xfm).exists()

cache = tmp_path / 'mycache'
monkeypatch.setenv('NIBABIES_POOCH_DIR', cache)
getter = RetrievePoochFiles(intermediate='MNIInfant:cohort-1', target='MNI152NLin6Asym')
outputs = getter.run().outputs

assert Path(outputs.int2tgt_xfm) == cache / 'from-MNIInfant+1_to-MNI152NLin6Asym_xfm.h5'
assert Path(outputs.tgt2int_xfm) == cache / 'from-MNI152NLin6Asym_to-MNIInfant+1_xfm.h5'
51 changes: 8 additions & 43 deletions nibabies/workflows/anatomical/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@
further use in downstream nodes.

"""
from nibabies.interfaces.download import RetrievePoochFiles

Check warning on line 348 in nibabies/workflows/anatomical/registration.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/registration.py#L348

Added line #L348 was not covered by tests
from nibabies.interfaces.patches import CompositeTransformUtil

ntpls = len(templates)
Expand Down Expand Up @@ -405,11 +406,9 @@
outputnode = pe.Node(niu.IdentityInterface(fields=out_fields), name='outputnode')

intermed_xfms = pe.MapNode(
niu.Function(
function=_load_intermediate_xfms, output_names=['int2std_xfm', 'std2int_xfm']
),
name='intermed_xfms',
iterfield=['std'],
RetrievePoochFiles(),
name='retrieve_xfms',
iterfield=['target'],
run_without_submitting=True,
)

Expand Down Expand Up @@ -464,10 +463,10 @@
# Transform concatenation
(inputnode, dis_anat2int, [('anat2int_xfm', 'in_file')]),
(inputnode, dis_int2anat, [('int2anat_xfm', 'in_file')]),
(inputnode, intermed_xfms, [('intermediate', 'intermediate')]),
(inputnode, intermed_xfms, [('template', 'std')]),
(intermed_xfms, dis_int2std, [('int2std_xfm', 'in_file')]),
(intermed_xfms, dis_std2int, [('std2int_xfm', 'in_file')]),
(inputnode, intermed_xfms, [('intermediate', 'intermediate'),
('template', 'target')]),
(intermed_xfms, dis_int2std, [('int2tgt_xfm', 'in_file')]),
(intermed_xfms, dis_std2int, [('tgt2int_xfm', 'in_file')]),
(dis_anat2int, order_anat2std, [
('affine_transform', 'in1'),
('displacement_field', 'in2'),
Expand Down Expand Up @@ -505,40 +504,6 @@
return workflow


def _load_intermediate_xfms(intermediate, std):
"""Fetch transforms from the OSF repository (https://osf.io/y763j/)."""
import json
from pathlib import Path

import pooch

from nibabies.data import load

xfms = json.loads(load('xfm_manifest.json').read_text())
# MNIInfant:cohort-1 -> MNIInfant+1
intmed = intermediate.replace(':cohort-', '+')

int2std_name = f'from-{intmed}_to-{std}_xfm.h5'
int2std_meta = xfms[int2std_name]
int2std = pooch.retrieve(
url=int2std_meta['url'],
path=Path.cwd(),
known_hash=int2std_meta['hash'],
fname=int2std_name,
)

std2int_name = f'from-{std}_to-{intmed}_xfm.h5'
std2int_meta = xfms[std2int_name]
std2int = pooch.retrieve(
url=std2int_meta['url'],
path=Path.cwd(),
known_hash=std2int_meta['hash'],
fname=std2int_name,
)

return int2std, std2int


def _create_inverse_composite(in_file, out_file='inverse_composite.h5'):
"""
Build a composite transform with SimpleITK.
Expand Down