2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -90,6 +90,8 @@
title: Create a document dataset
- local: nifti_dataset
title: Create a medical imaging dataset
- local: bids_dataset
title: Load a BIDS dataset
title: "Vision"
- sections:
- local: nlp_load
63 changes: 63 additions & 0 deletions docs/source/bids_dataset.mdx
@@ -0,0 +1,63 @@
# BIDS Dataset

[BIDS (Brain Imaging Data Structure)](https://bids.neuroimaging.io/) is a standard for organizing and describing neuroimaging and behavioral data. The `datasets` library supports loading BIDS datasets directly, leveraging `pybids` for parsing and `nibabel` for handling NIfTI files.

<Tip>

To use the BIDS loader, you need to install the `bids` extra (which installs `pybids` and `nibabel`):

```bash
pip install datasets[bids]
```

</Tip>

## Loading a BIDS Dataset

You can load a BIDS dataset by pointing to its root directory (containing `dataset_description.json`):

```python
from datasets import load_dataset

# Load a local BIDS dataset
ds = load_dataset("bids", data_dir="/path/to/bids/dataset")

# Access the first example
print(ds["train"][0])
# {
# 'subject': '01',
# 'session': 'baseline',
# 'datatype': 'anat',
# 'suffix': 'T1w',
# 'nifti': <nibabel.nifti1.Nifti1Image>,
# ...
# }
```

The `nifti` column contains `nibabel` image objects, which can be visualized interactively in Jupyter notebooks.
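
For example, you can pull the voxel data into a NumPy array (a minimal sketch, assuming the decoded value behaves like a standard `nibabel` image):

```python
img = ds["train"][0]["nifti"]   # nibabel.nifti1.Nifti1Image
data = img.get_fdata()          # voxel data as a NumPy array
print(data.shape)               # e.g. (256, 256, 180)
print(img.affine)               # 4x4 voxel-to-world affine
```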

## Filtering

You can filter the dataset by BIDS entities like `subject`, `session`, and `datatype` when loading:

```python
# Load only specific subjects and datatypes
ds = load_dataset(
"bids",
data_dir="/path/to/bids/dataset",
subjects=["01", "05", "10"],
sessions=["pre", "post"],
datatypes=["func"],
)
```

## Metadata

BIDS datasets often include JSON sidecar files with metadata (e.g., scanner parameters). This metadata is loaded into the `metadata` column as a JSON string.

```python
import json

metadata = json.loads(ds["train"][0]["metadata"])
print(metadata["RepetitionTime"])
```
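
## Layout caching

The loader also accepts a `database_path` argument, which is forwarded to `pybids`' `BIDSLayout` to cache the parsed layout index. A minimal sketch (assuming a writable cache location; the first load builds the index, and later loads reuse it):

```python
ds = load_dataset(
    "bids",
    data_dir="/path/to/bids/dataset",
    database_path="/path/to/layout_cache",  # reused by pybids on subsequent loads
)
```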
3 changes: 3 additions & 0 deletions setup.py
@@ -210,6 +210,8 @@

NIBABEL_REQUIRE = ["nibabel>=5.3.2", "ipyniivue==2.4.2"]

PYBIDS_REQUIRE = ["pybids>=0.21.0"] + NIBABEL_REQUIRE

EXTRAS_REQUIRE = {
"audio": AUDIO_REQUIRE,
"vision": VISION_REQUIRE,
@@ -228,6 +230,7 @@
"docs": DOCS_REQUIRE,
"pdfs": PDFS_REQUIRE,
"nibabel": NIBABEL_REQUIRE,
"bids": PYBIDS_REQUIRE,
}

setup(
1 change: 1 addition & 0 deletions src/datasets/config.py
@@ -140,6 +140,7 @@
TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None
NIBABEL_AVAILABLE = importlib.util.find_spec("nibabel") is not None
PYBIDS_AVAILABLE = importlib.util.find_spec("bids") is not None

# Optional compression tools
RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
2 changes: 2 additions & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -6,6 +6,7 @@

from .arrow import arrow
from .audiofolder import audiofolder
from .bids import bids
from .cache import cache
from .csv import csv
from .hdf5 import hdf5
@@ -48,6 +49,7 @@ def _hash_python_lines(lines: list[str]) -> str:
"videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
"pdffolder": (pdffolder.__name__, _hash_python_lines(inspect.getsource(pdffolder).splitlines())),
"niftifolder": (niftifolder.__name__, _hash_python_lines(inspect.getsource(niftifolder).splitlines())),
"bids": (bids.__name__, _hash_python_lines(inspect.getsource(bids).splitlines())),
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
"xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
"hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/bids/__init__.py
@@ -0,0 +1 @@
from .bids import Bids, BidsConfig
116 changes: 116 additions & 0 deletions src/datasets/packaged_modules/bids/bids.py
@@ -0,0 +1,116 @@
import json
import os
from dataclasses import dataclass
from typing import Optional

import datasets
from datasets import config


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class BidsConfig(datasets.BuilderConfig):
"""BuilderConfig for BIDS datasets."""

data_dir: Optional[str] = None
database_path: Optional[str] = None # For pybids caching
subjects: Optional[list[str]] = None # Filter by subject
sessions: Optional[list[str]] = None # Filter by session
datatypes: Optional[list[str]] = None # Filter by datatype


class Bids(datasets.GeneratorBasedBuilder):
"""BIDS dataset loader using pybids."""

BUILDER_CONFIG_CLASS = BidsConfig

def _info(self):
if not config.PYBIDS_AVAILABLE:
raise ImportError("To load BIDS datasets, please install pybids: pip install pybids")
if not config.NIBABEL_AVAILABLE:
raise ImportError("To load BIDS datasets, please install nibabel: pip install nibabel")

return datasets.DatasetInfo(
features=datasets.Features(
{
"subject": datasets.Value("string"),
"session": datasets.Value("string"),
"datatype": datasets.Value("string"),
"suffix": datasets.Value("string"),
"task": datasets.Value("string"),
"run": datasets.Value("string"),
"path": datasets.Value("string"),
"nifti": datasets.Nifti(),
"metadata": datasets.Value("string"),
> **Contributor comment:** I think this might be something for another PR, but having a dict-like object here would be more beneficial. Not quite sure how we could achieve that; maybe through pyarrow's map and union types, or a dedicated feature for BIDSMetadata (or for dictionaries in general?).
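As an illustration of that suggestion, a minimal sketch of a map-typed metadata column in pyarrow (hypothetical; Arrow maps require a single value type, so mixed-type sidecar values would still need some encoding):

```python
import pyarrow as pa

# Hypothetical: metadata as key/value pairs instead of a JSON string.
# Arrow map values share one type, so numbers are stringified here.
metadata_type = pa.map_(pa.string(), pa.string())
arr = pa.array(
    [[("RepetitionTime", "2.0"), ("EchoTime", "0.03")]],
    type=metadata_type,
)
print(arr[0])  # MapScalar of key/value pairs
```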

}
)
)

def _split_generators(self, dl_manager):
from bids import BIDSLayout

if not self.config.data_dir:
raise ValueError("data_dir is required for BIDS datasets")

if not os.path.isdir(self.config.data_dir):
raise ValueError(f"data_dir does not exist: {self.config.data_dir}")

desc_file = os.path.join(self.config.data_dir, "dataset_description.json")
if not os.path.exists(desc_file):
raise ValueError(f"Not a valid BIDS dataset: missing dataset_description.json in {self.config.data_dir}")

layout = BIDSLayout(
self.config.data_dir,
database_path=self.config.database_path,
validate=False, # Don't fail on minor validation issues
)

# Build query kwargs
query = {"extension": [".nii", ".nii.gz"]}
if self.config.subjects:
query["subject"] = self.config.subjects
if self.config.sessions:
query["session"] = self.config.sessions
if self.config.datatypes:
query["datatype"] = self.config.datatypes

# Get all NIfTI files
nifti_files = layout.get(**query)

if not nifti_files:
logger.warning(
f"No NIfTI files found in {self.config.data_dir} with filters: {query}. "
"Check that the dataset is valid BIDS and filters match existing data."
)

return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"layout": layout, "files": nifti_files},
)
]

def _generate_examples(self, layout, files):
for idx, bids_file in enumerate(files):
entities = bids_file.get_entities()

# Get JSON sidecar metadata
metadata = layout.get_metadata(bids_file.path)
metadata_str = json.dumps(metadata) if metadata else "{}"

yield (
idx,
{
"subject": entities.get("subject"),
"session": entities.get("session"),
"datatype": entities.get("datatype"),
"suffix": entities.get("suffix"),
"task": entities.get("task"),
"run": str(entities.get("run")) if entities.get("run") else None,
"path": bids_file.path,
"nifti": bids_file.path,
"metadata": metadata_str,
},
)
120 changes: 120 additions & 0 deletions tests/packaged_modules/test_bids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json

import numpy as np
import pytest

import datasets.config


@pytest.fixture
def minimal_bids_dataset(tmp_path):
"""Minimal valid BIDS dataset with one subject, one T1w scan."""
# dataset_description.json (required)
(tmp_path / "dataset_description.json").write_text(
json.dumps({"Name": "Test BIDS Dataset", "BIDSVersion": "1.10.1"})
)

# Create subject/anat folder
anat_dir = tmp_path / "sub-01" / "anat"
anat_dir.mkdir(parents=True)

# Create dummy NIfTI
if datasets.config.NIBABEL_AVAILABLE:
import nibabel as nib

data = np.zeros((4, 4, 4), dtype=np.float32)
img = nib.Nifti1Image(data, np.eye(4))
nib.save(img, str(anat_dir / "sub-01_T1w.nii.gz"))
else:
# Fallback if nibabel not available (shouldn't happen in test env ideally)
(anat_dir / "sub-01_T1w.nii.gz").write_bytes(b"DUMMY NIFTI CONTENT")

# JSON sidecar
(anat_dir / "sub-01_T1w.json").write_text(json.dumps({"RepetitionTime": 2.0}))

return str(tmp_path)


@pytest.fixture
def multi_subject_bids(tmp_path):
"""BIDS dataset with multiple subjects and sessions."""
(tmp_path / "dataset_description.json").write_text(
json.dumps({"Name": "Multi-Subject Test", "BIDSVersion": "1.10.1"})
)

data = np.zeros((4, 4, 4), dtype=np.float32)

if datasets.config.NIBABEL_AVAILABLE:
import nibabel as nib
else:
nib = None

for sub in ["01", "02"]:
for ses in ["baseline", "followup"]:
anat_dir = tmp_path / f"sub-{sub}" / f"ses-{ses}" / "anat"
anat_dir.mkdir(parents=True)

file_path = anat_dir / f"sub-{sub}_ses-{ses}_T1w.nii.gz"
if nib:
img = nib.Nifti1Image(data, np.eye(4))
nib.save(img, str(file_path))
else:
file_path.write_bytes(b"DUMMY NIFTI CONTENT")

(anat_dir / f"sub-{sub}_ses-{ses}_T1w.json").write_text(json.dumps({"RepetitionTime": 2.0}))

return str(tmp_path)


def test_bids_module_imports():
from datasets.packaged_modules.bids import Bids, BidsConfig

assert Bids is not None
assert BidsConfig is not None


def test_bids_requires_pybids(monkeypatch):
"""Test helpful error when pybids not installed."""
from datasets.packaged_modules.bids.bids import Bids

monkeypatch.setattr(datasets.config, "PYBIDS_AVAILABLE", False)

with pytest.raises(ImportError, match="pybids"):
Bids()


@pytest.mark.skipif(
not datasets.config.PYBIDS_AVAILABLE or not datasets.config.NIBABEL_AVAILABLE,
reason="pybids or nibabel not installed",
)
def test_bids_loads_single_subject(minimal_bids_dataset):
from datasets import load_dataset

ds = load_dataset("bids", data_dir=minimal_bids_dataset)

assert "train" in ds
assert len(ds["train"]) == 1

sample = ds["train"][0]
assert sample["subject"] == "01"
assert sample["suffix"] == "T1w"
assert sample["datatype"] == "anat"
assert sample["session"] is None


@pytest.mark.skipif(
not datasets.config.PYBIDS_AVAILABLE or not datasets.config.NIBABEL_AVAILABLE,
reason="pybids or nibabel not installed",
)
def test_bids_multi_subject(multi_subject_bids):
from datasets import load_dataset

ds = load_dataset("bids", data_dir=multi_subject_bids)

assert len(ds["train"]) == 4 # 2 subjects × 2 sessions

subjects = {sample["subject"] for sample in ds["train"]}
assert subjects == {"01", "02"}

sessions = {sample["session"] for sample in ds["train"]}
assert sessions == {"baseline", "followup"}