diff --git a/nibabel/__init__.py b/nibabel/__init__.py index fca22ccc99..25b14d398f 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -72,6 +72,7 @@ def setup_test(): from .freesurfer import MGHImage from .funcs import (squeeze_image, concat_images, four_to_three, as_closest_canonical) +from .spatialimages import image_like from .orientations import (io_orientation, orientation_affine, flip_axis, OrientationError, apply_orientation, aff2axcodes) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index ede0820065..4f658fa91a 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -537,7 +537,7 @@ def get_affine(self): return self.affine @classmethod - def from_image(klass, img): + def from_image(klass, img, data=None): ''' Class method to create new instance of own class from `img` Parameters @@ -551,7 +551,7 @@ def from_image(klass, img): cimg : ``spatialimage`` instance Image, of our own class ''' - return klass(img.dataobj, + return klass(img.dataobj if data is None else data, img.affine, klass.header_class.from_header(img.header), extra=img.extra.copy()) @@ -634,3 +634,10 @@ def as_reoriented(self, ornt): new_aff = self.affine.dot(inv_ornt_aff(ornt, self.shape)) return self.__class__(t_arr, new_aff, self.header) + + +def image_like(img, data): + ''' Create new SpatialImage with metadata of `img`, and data + contained in `data`. + ''' + return img.from_image(img, data) diff --git a/nibabel/tests/test_spatialimages.py b/nibabel/tests/test_spatialimages.py index b0f571023d..b8569543b4 100644 --- a/nibabel/tests/test_spatialimages.py +++ b/nibabel/tests/test_spatialimages.py @@ -16,7 +16,7 @@ from io import BytesIO from ..spatialimages import (SpatialHeader, SpatialImage, HeaderDataError, - Header, ImageDataError) + Header, ImageDataError, image_like) from ..imageclasses import spatial_axes_first from unittest import TestCase @@ -659,3 +659,13 @@ class MyHeader(Header): MyHeader() assert_equal(len(w), 1) + + +def test_image_like(): + zeros = SpatialImage(np.zeros((2, 3, 4)), np.eye(4)) + ones = image_like(zeros, np.ones((2, 3, 4))) + + assert np.all(ones.dataobj != zeros.dataobj) + assert np.all(ones.affine == zeros.affine) + assert ones.header == zeros.header + assert ones.extra == zeros.extra