195 changes: 195 additions & 0 deletions mmai25_hackathon/dataset.py
@@ -15,6 +15,10 @@
from torch.utils.data import Dataset, Sampler
from torch_geometric.data import DataLoader

from .load_data.cxr import load_chest_xray_image, load_mimic_cxr_metadata
from .load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list
from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list

__all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"]


@@ -101,13 +105,190 @@ class CXRDataset(BaseDataset):

    def __init__(self, data_path: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # `args` is a positional tuple, so `args.data_path` would raise an
        # AttributeError; take the metadata path as an explicit argument.
        self.records = load_mimic_cxr_metadata(data_path)
        self.subject_ids = self.records["subject_id"].tolist()

    def __len__(self):
        return len(self.records)

    def __getitem__(self, sample_ID: int):
        """Return all CXR samples recorded for one subject ID."""
        # Iterating a DataFrame directly yields column names, not rows;
        # use iterrows() to visit each matching record.
        records = self.records[self.records.subject_id == sample_ID]
        samples = []
        for _, row in records.iterrows():
            image = load_chest_xray_image(row["cxr_path"])
            samples.append({"image": image, "subject_id": row["subject_id"]})
        return samples

    def modality(self) -> str:
        """Return the modality of the dataset."""
        return "CXR"


class ECGDataset(BaseDataset):
"""Example subclass for an ECG dataset."""

    def __init__(self, data_path: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Load the ECG record list (subject_id, hea_path, ...) as a DataFrame.
        self.records = load_mimic_iv_ecg_record_list(data_path)
        self.subject_ids = self.records["subject_id"].tolist()

    def __len__(self):
        return len(self.records)

    def __getitem__(self, sample_ID: int):
        """Return all ECG samples recorded for one subject ID."""
        records = self.records[self.records.subject_id == sample_ID]
        samples = []
        for _, row in records.iterrows():
            signals, fields = load_ecg_record(row["hea_path"])
            samples.append({"signals": signals, "fields": fields, "subject_id": row["subject_id"]})
        return samples

    def __repr__(self) -> str:
        """Return a string representation of the dataset."""
        return f"{self.__class__.__name__}({self.extra_repr()})"

    def extra_repr(self) -> str:
        """Return any extra information about the dataset."""
        return f"sample_size={len(self.subject_ids)}"

    def modality(self) -> str:
        """Return the modality of the dataset."""
        return "ECG"

    def __add__(self, other):
        """
        Combine with another dataset by outer-merging record lists on subject_id.

        Override in subclasses to implement multimodal aggregation.

        Args:
            other: Another dataset to combine with this one.

        Initial Idea:
            Use `__add__` to align and merge heterogeneous modalities into a single
            dataset, keeping shared IDs synchronized.
            Note: This is not mandatory; treat it as a sketch you can refine or replace.
        """
        self.records = self.records.merge(other.records, on="subject_id", suffixes=("", "_other"), how="outer")
        self.subject_ids = self.records["subject_id"].tolist()

        # TODO: take a single sample from `other`, find the corresponding rows in
        # this dataset (matching subject_id), and call __getitem__ on those?

        return self


class EchoDataset(BaseDataset):
"""Example subclass for an ECHO dataset."""

    def __init__(self, data_path: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.records = load_mimic_iv_echo_record_list(data_path)
        self.subject_ids = self.records["subject_id"].tolist()

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.records)

    def __getitem__(self, sample_ID: int):
        """Return all echo samples recorded for one subject ID."""
        records = self.records[self.records.subject_id == sample_ID]
        samples = []
        for _, row in records.iterrows():
            frames, meta = load_echo_dicom(row["echo_path"])
            samples.append({"frames": frames, "metadata": meta, "subject_id": row["subject_id"]})
        return samples

    def extra_repr(self) -> str:
        """Return any extra information about the dataset."""
        return f"sample_size={len(self)}, subjects={len(set(self.subject_ids))}"

    def modality(self) -> str:
        """Return the modality of the dataset."""
        # Uppercase for consistency with the "CXR"/"ECG" modality keys.
        return "ECHO"

    def __add__(self, other):
        """
        Combine with another dataset by outer-merging record lists on subject_id.

        Override in subclasses to implement multimodal aggregation.

        Args:
            other: Another dataset to combine with this one.

        Initial Idea:
            Use `__add__` to align and merge heterogeneous modalities into a single
            dataset, keeping shared IDs synchronized.
            Note: This is not mandatory; treat it as a sketch you can refine or replace.
        """
        self.records = self.records.merge(other.records, on="subject_id", suffixes=("", "_other"), how="outer")
        self.subject_ids = self.records["subject_id"].tolist()
        return self
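A quick usage sketch of the `__add__` alignment idea above, under the same assumptions (hypothetical paths, no extra `BaseDataset` constructor arguments):

```python
# Hypothetical local paths; substitute your own MIMIC-IV extracts.
ecg_ds = ECGDataset(data_path="data/mimic-iv-ecg")
echo_ds = EchoDataset(data_path="data/mimic-iv-echo")

# The outer merge keeps every subject seen in either modality; clashing
# echo column names receive the "_other" suffix.
combined = ecg_ds + echo_ds
print(len(combined.subject_ids))
```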


class MultimodalDataset(BaseDataset):
"""Example subclass for a multimodal dataset."""

    def __init__(self, datasets: list[BaseDataset], *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not all(isinstance(ds, BaseDataset) for ds in datasets):
            raise ValueError("All elements in datasets must be instances of BaseDataset.")
        self.datasets = datasets

        # Index the union of all subject IDs across the child datasets; a subject
        # missing a modality simply yields an empty list for that key.
        self.subject_ids = list(set().union(*(ds.subject_ids for ds in datasets)))
        print(f"MultimodalDataset initialized with {len(self.subject_ids)} unique subjects.")

    def __len__(self) -> int:
        """Return the number of subjects in the dataset."""
        # There is no single `self.dataset` attribute on this class; the
        # subject-ID union defines the index space.
        return len(self.subject_ids)

    def __getitem__(self, idx: int):
        """Return one subject's samples across all modalities, keyed by modality."""
        subject_ID = self.subject_ids[idx]
        return {ds.modality(): ds[subject_ID] for ds in self.datasets}

        # Alternative sketch: treat datasets[0] as primary and fold the others
        # in via __add__:
        # primary_ds = self.datasets[0]
        # item = primary_ds[idx]
        # for ds in self.datasets[1:]:
        #     item = ds + item
        # return item

    def extra_repr(self) -> str:
        """Return any extra information about the dataset."""
        modalities = ", ".join(ds.modality() for ds in self.datasets)
        return f"subjects={len(self.subject_ids)}, modalities=[{modalities}]"
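Putting the pieces together, a sketch of how `MultimodalDataset` composes the single-modality datasets (paths again hypothetical): each index maps to one subject, and a subject missing a modality comes back as an empty list under that key.

```python
mm_ds = MultimodalDataset(
    datasets=[
        CXRDataset(data_path="data/mimic-cxr"),     # hypothetical path
        ECGDataset(data_path="data/mimic-iv-ecg"),  # hypothetical path
    ]
)

sample = mm_ds[0]
cxr_items = sample["CXR"]  # may be empty if this subject has no CXR study
ecg_items = sample["ECG"]  # list of {"signals", "fields", "subject_id"} dicts
```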


class BaseDataLoader(DataLoader):
@@ -132,6 +313,20 @@ class BaseDataLoader(DataLoader):
    Note: This is not a hard requirement. Consider it a future-facing idea you can evolve.
    """

    def __init__(
        self,
        dataset: BaseDataset,
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: list | None = None,
        exclude_keys: list | None = None,
        **kwargs,
    ):
        super().__init__(dataset, batch_size, shuffle, follow_batch, exclude_keys, **kwargs)

        # collate_fn=lambda data_list: Batch.from_data_list(
        #     data_list, follow_batch),

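If the commented-out `collate_fn` route gets pursued for the dict-shaped multimodal samples, a minimal sketch could look like the following. `multimodal_collate` and its grouping strategy are assumptions, not part of this PR; and since `torch_geometric`'s `DataLoader` installs its own collater, the sketch pairs it with the plain `torch.utils.data.DataLoader`.

```python
from torch.utils.data import DataLoader as TorchDataLoader


def multimodal_collate(batch):
    """Group per-subject dicts into one dict of lists per modality.

    Deliberately simple: a real pipeline would stack or pad tensors here.
    """
    collated = {}
    for sample in batch:
        for modality, items in sample.items():
            collated.setdefault(modality, []).append(items)
    return collated


# loader = TorchDataLoader(mm_ds, batch_size=4, collate_fn=multimodal_collate)
```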

class MultimodalDataLoader(BaseDataLoader):
"""Example dataloader for handling multiple data modalities."""
4 changes: 2 additions & 2 deletions tests/test_dataset.py
@@ -77,8 +77,8 @@ def __len__(self) -> int:
        def __getitem__(self, idx: int) -> Data:  # type: ignore[name-defined]
            return self._graphs[idx]

-    ds = GraphDataset()
-    loader = BaseDataLoader(ds, batch_size=2, shuffle=False)
+    ds = GraphDataset()  # type: ignore[arg-type]
+    loader = BaseDataLoader(ds, batch_size=2, shuffle=False)  # type: ignore[arg-type]

    total_graphs = 0
    for batch in loader: