Skip to content

Commit 3356d74

Browse files
authored
Fix embed storage nifti (#7853)
* WIP: allow uploading of nifti * remove debug statements and fix test * remove debug statements * remove debug statements
1 parent 91f96a0 commit 3356d74

File tree

2 files changed

+88
-10
lines changed

2 files changed

+88
-10
lines changed

src/datasets/features/nifti.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..download.download_config import DownloadConfig
1111
from ..table import array_cast
1212
from ..utils.file_utils import is_local_path, xopen
13-
from ..utils.py_utils import string_to_dict
13+
from ..utils.py_utils import no_op_if_value_is_null, string_to_dict
1414

1515

1616
if TYPE_CHECKING:
@@ -125,9 +125,6 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
125125
Returns:
126126
`nibabel.Nifti1Image` objects
127127
"""
128-
if not self.decode:
129-
raise NotImplementedError("Decoding is disabled for this feature. Please use Nifti(decode=True) instead.")
130-
131128
if config.NIBABEL_AVAILABLE:
132129
import nibabel as nib
133130
else:
@@ -141,6 +138,9 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
141138
if path is None:
142139
raise ValueError(f"A nifti should have one of 'path' or 'bytes' but both are None in {value}.")
143140
else:
141+
# gzipped files have the structure: 'gzip://T1.nii::<local_path>'
142+
if path.startswith("gzip://") and is_local_path(path.split("::")[-1]):
143+
path = path.split("::")[-1]
144144
if is_local_path(path):
145145
nifti = nib.load(path)
146146
else:
@@ -150,11 +150,10 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
150150
if source_url.startswith(config.HF_ENDPOINT)
151151
else config.HUB_DATASETS_HFFS_URL
152152
)
153-
try:
154-
repo_id = string_to_dict(source_url, pattern)["repo_id"]
155-
token = token_per_repo_id.get(repo_id)
156-
except ValueError:
157-
token = None
153+
source_url_fields = string_to_dict(source_url, pattern)
154+
token = (
155+
token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
156+
)
158157
download_config = DownloadConfig(token=token)
159158
with xopen(path, "rb", download_config=download_config) as f:
160159
nifti = nib.load(f)
@@ -172,6 +171,46 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
172171

173172
return nifti
174173

174+
def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
    """Embed NIfTI file contents into the Arrow array.

    Rows that already carry bytes keep them. Rows that only carry a path get
    their bytes read from that path with `xopen`. Every non-null path is
    reduced to its base name in the returned array.

    Args:
        storage (`pa.StructArray`):
            PyArrow array to embed.
        token_per_repo_id (`dict`, *optional*):
            Mapping queried with the `repo_id` parsed from a Hub URL to get the
            token passed to `DownloadConfig`. Treated as empty when `None`.

    Returns:
        `pa.StructArray`: Array in the NIfTI arrow storage type, that is
            `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
    """
    repo_tokens = {} if token_per_repo_id is None else token_per_repo_id

    @no_op_if_value_is_null
    def path_to_bytes(path):
        # Chained URLs look like '<protocol>://<inner>::<source_url>'; the last
        # segment is the one matched against the Hub URL patterns.
        source_url = path.split("::")[-1]
        if source_url.startswith(config.HF_ENDPOINT):
            pattern = config.HUB_DATASETS_URL
        else:
            pattern = config.HUB_DATASETS_HFFS_URL
        url_fields = string_to_dict(source_url, pattern)
        # No pattern match means no repo_id to look up, so no token is used.
        token = None if url_fields is None else repo_tokens.get(url_fields["repo_id"])
        download_config = DownloadConfig(token=token)
        with xopen(path, "rb", download_config=download_config) as f:
            return f.read()

    embedded = []
    for item in storage.to_pylist():
        if item is None:
            embedded.append(None)
        elif item["bytes"] is None:
            embedded.append(path_to_bytes(item["path"]))
        else:
            embedded.append(item["bytes"])
    bytes_array = pa.array(embedded, type=pa.binary())

    base_names = [None if p is None else os.path.basename(p) for p in storage.field("path").to_pylist()]
    path_array = pa.array(base_names, type=pa.string())

    # A row with no bytes after embedding is marked null as a whole.
    embedded_storage = pa.StructArray.from_arrays(
        [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
    )
    return array_cast(embedded_storage, self.pa_type)
213+
175214
def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]:
176215
"""If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary."""
177216
from .features import Value

tests/features/test_nifti.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from pathlib import Path
44

5+
import pyarrow as pa
56
import pytest
67

7-
from datasets import Dataset, Features, Nifti
8+
from datasets import Dataset, Features, Nifti, load_dataset
89
from src.datasets.features.nifti import encode_nibabel_image
910

1011
from ..utils import require_nibabel
@@ -89,3 +90,41 @@ def test_encode_nibabel_image(shared_datadir):
8990
assert isinstance(encoded_example_bytes, dict)
9091
assert encoded_example_bytes["bytes"] is not None and encoded_example_bytes["path"] is None
9192
# this cannot be converted back from bytes (yet)
93+
94+
95+
@require_nibabel
def test_embed_storage(shared_datadir):
    """Check that `Nifti.embed_storage` turns a path-only row into bytes nibabel can load back."""
    from io import BytesIO

    import nibabel as nib

    nifti_path = str(shared_datadir / "test_nifti.nii")
    img = nib.load(nifti_path)
    nifti = Nifti()

    # A single row carrying only a path, so embed_storage has to read the file itself.
    bytes_array = pa.array([None], type=pa.binary())
    path_array = pa.array([nifti_path], type=pa.string())
    storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"])

    embedded_storage = nifti.embed_storage(storage)

    embedded_bytes = embedded_storage[0]["bytes"].as_py()
    # Check before use: BytesIO(None) would build an empty buffer, and the failure
    # would then surface inside nibabel instead of pointing at the missing embed.
    assert embedded_bytes is not None

    bio = BytesIO(embedded_bytes)
    fh = nib.FileHolder(fileobj=bio)
    nifti_img = nib.Nifti1Image.from_file_map({"header": fh, "image": fh})

    # The image rebuilt from the embedded bytes must match the one loaded from disk.
    assert nifti_img.header == img.header
    assert (nifti_img.affine == img.affine).all()
    assert (nifti_img.get_fdata() == img.get_fdata()).all()
121+
122+
123+
@require_nibabel
def test_load_zipped_file_locally(shared_datadir):
    """Check that a local gzipped NIfTI file loaded via "niftifolder" decodes to a nibabel image."""
    import nibabel as nib

    gz_file = shared_datadir / "test_nifti.nii.gz"

    dataset = load_dataset("niftifolder", data_files=str(gz_file))
    first_example = dataset["train"][0]
    assert isinstance(first_example["nifti"], nib.nifti1.Nifti1Image)

0 commit comments

Comments
 (0)