diff --git a/src/datasets/features/nifti.py b/src/datasets/features/nifti.py index 3b118d1cc12..f3c34d29266 100644 --- a/src/datasets/features/nifti.py +++ b/src/datasets/features/nifti.py @@ -27,7 +27,7 @@ class Nifti1ImageWrapper(nib.nifti1.Nifti1Image): def __init__(self, nifti_image: nib.nifti1.Nifti1Image): super().__init__( - dataobj=nifti_image.get_fdata(), + dataobj=nifti_image.dataobj, affine=nifti_image.affine, header=nifti_image.header, extra=nifti_image.extra, diff --git a/tests/features/test_nifti.py b/tests/features/test_nifti.py index b5f0be42f3e..527a5083c3e 100644 --- a/tests/features/test_nifti.py +++ b/tests/features/test_nifti.py @@ -128,3 +128,22 @@ def test_load_zipped_file_locally(shared_datadir): ds = load_dataset("niftifolder", data_files=nifti_path) assert isinstance(ds["train"][0]["nifti"], nib.nifti1.Nifti1Image) + + +@require_nibabel +def test_nifti_lazy_loading(shared_datadir): + import nibabel as nib + import numpy as np + + nifti_path = str(shared_datadir / "test_nifti.nii.gz") + nifti = Nifti() + encoded_example = nifti.encode_example(nifti_path) + decoded_example = nifti.decode_example(encoded_example) + + # Verify that the data object is an ArrayProxy (lazy) and not a numpy array (dense) + assert nib.is_proxy(decoded_example.dataobj) + assert not isinstance(decoded_example.dataobj, np.ndarray) + + # Verify that we can still access the data + data = decoded_example.get_fdata() + assert data.shape == (80, 80, 10)