Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,5 @@ __marimo__/

# Streamlit
.streamlit/secrets.toml

mimic-iv/
Empty file added group_code/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions group_code/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# gpt generated code
def combine_dataframes(dfs, tags, on="subject_id"):
    """
    Inner-join several dataframes on a shared key column.

    Every column other than ``on`` is renamed to ``"<tag>_<column>"`` using the
    tag at the same position as its dataframe, so columns coming from
    different sources stay distinguishable after the merge. The input
    dataframes are not modified (``rename`` returns new objects).

    Args:
        dfs: Dataframes to combine; each is merged on the ``on`` column.
        tags: Column prefixes, one per dataframe, in the same order as ``dfs``.
        on: Name of the key column to merge on.

    Returns:
        A dataframe containing only the key values present in every input.

    Raises:
        ValueError: If ``dfs`` and ``tags`` differ in length, or if no
            dataframe is given.
    """
    dfs = list(dfs)
    tags = list(tags)

    # Explicit exceptions instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let mismatched inputs through silently.
    if len(dfs) != len(tags):
        raise ValueError("Number of dataframes must match number of tags")
    if not dfs:
        raise ValueError("At least one dataframe is required")

    # Prefix every non-key column with the tag of its dataframe.
    renamed_dfs = [
        df.rename(columns={col: f"{tag}_{col}" for col in df.columns if col != on})
        for df, tag in zip(dfs, tags)
    ]

    # Fold the dataframes together one inner join at a time.
    combined = renamed_dfs[0]
    for df in renamed_dfs[1:]:
        combined = combined.merge(df, on=on, how="inner")

    return combined
46 changes: 46 additions & 0 deletions group_code/mm_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from group_code.helper import combine_dataframes
from group_code.uni_model.ecg import ECG_uni
from group_code.uni_model.echo import ECHO_uni
from group_code.uni_model.ehr import EHR_uni
from mmai25_hackathon.dataset import BaseDataset


class EEE_dataset(BaseDataset):
    """
    Multimodal dataset pairing per-subject ECG, echo and EHR data.

    One unimodal dataset is created per requested modality and stored in
    ``uni_dict`` under the modality name. ``combined_records`` is the inner
    join on ``subject_id`` of the ``records`` tables of every loaded modality
    except ``"ehr"``; the EHR dataset is left out of the join and is only
    queried per subject in ``__getitem__``.
    """

    def __init__(self, data_roots, data_mods):
        """
        Args:
            data_roots: Root directories, one per modality, ordered like
                ``data_mods``.
            data_mods: Modality names; ``"ecg"``, ``"ehr"`` and ``"echo"`` are
                recognised, anything else is reported and skipped.

        Raises:
            ValueError: If no modality other than ``"ehr"`` ends up loaded, so
                there is no record table to build ``combined_records`` from.
        """
        self.uni_dict = {}
        self._load_modalities(data_roots, data_mods)
        self._rebuild_combined_records()

    def _load_modalities(self, data_roots, data_mods):
        # Create (or replace) the unimodal dataset for each recognised modality.
        builders = {"ecg": ECG_uni, "ehr": EHR_uni, "echo": ECHO_uni}
        for mod, r_path in zip(data_mods, data_roots):
            builder = builders.get(mod)
            if builder is None:
                print("Modality not supported.")
                continue
            self.uni_dict[mod] = builder(r_path)

    def _rebuild_combined_records(self):
        # Join the record tables of every loaded modality except "ehr".
        tabular = {key: val for key, val in self.uni_dict.items() if key != "ehr"}
        if not tabular:
            # Fail with a readable message instead of an IndexError from the
            # merge helper being handed an empty list.
            raise ValueError("At least one modality other than 'ehr' is required to build combined records")
        self.combined_records = combine_dataframes(
            [val.records for val in tabular.values()],
            list(tabular),
        )

    def __len__(self) -> int:
        """Return the number of rows in ``combined_records``."""
        return len(self.combined_records)

    def __getitem__(self, idx: int):
        """
        Return ``{modality: fetch result}`` for the subject at row position ``idx``.

        Each value is whatever the unimodal dataset's ``fetch`` returns for
        that subject.
        """
        # Read from the column rather than from the row: a row taken with
        # ``iloc`` is converted to a single common dtype, which can coerce the id.
        subject_id = self.combined_records["subject_id"].iloc[idx]
        return {key: val.fetch(subject_id) for key, val in self.uni_dict.items()}

    def __add__(self, data_roots, data_mods):
        """
        Load additional modalities in place and rebuild ``combined_records``.

        Accepts the same modality names as ``__init__``; a modality that is
        already loaded is replaced. Returns ``None``.

        Because this takes two arguments it cannot be triggered by the ``+``
        operator, which passes a single operand; call it explicitly as
        ``ds.__add__(roots, mods)``.
        """
        self._load_modalities(data_roots, data_mods)
        self._rebuild_combined_records()

    def get_idx_from_sub_id(self, subject_id):
        """
        Return the row position of the first record of ``subject_id``.

        The result is positional, so it can be passed straight to
        ``__getitem__``.

        Raises:
            IndexError: If ``subject_id`` is not present in ``combined_records``.
        """
        hits = (self.combined_records["subject_id"] == subject_id).to_numpy().nonzero()[0]
        if len(hits) == 0:
            raise IndexError(f"subject_id {subject_id!r} not found in combined records")
        return int(hits[0])
24 changes: 24 additions & 0 deletions group_code/uni_model/ecg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from torch.utils.data import Dataset

from mmai25_hackathon.load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list


class ECG_uni(Dataset):
    """
    Unimodal dataset over the ECG record list found under a root directory.

    ``records`` is the table returned by ``load_mimic_iv_ecg_record_list``;
    this class reads its ``subject_id`` and ``hea_path`` columns.
    """

    def __init__(self, mod_root):
        self.records = load_mimic_iv_ecg_record_list(mod_root)

    def __len__(self) -> int:
        """Return the number of rows in ``records``."""
        return len(self.records)

    def __getitem__(self, idx: int):
        """
        Load the ECG record at row position ``idx``.

        Returns:
            ``{subject_id: [sig, fields]}`` where ``sig`` and ``fields`` are the
            two values returned by ``load_ecg_record`` for the row's ``hea_path``.
        """
        row = self.records.iloc[idx]
        subject_id = row["subject_id"]
        sig, fields = load_ecg_record(row["hea_path"])

        return {subject_id: [sig, fields]}

    def get_idx_by_subject(self, subject_id):
        """
        Return the row position of the first record of ``subject_id``.

        The position (not the index label) is returned because ``__getitem__``
        looks rows up with ``iloc``; the two differ whenever ``records`` does
        not carry a default 0..n-1 index.

        Raises:
            IndexError: If ``subject_id`` does not occur in ``records``.
        """
        hits = (self.records["subject_id"] == subject_id).to_numpy().nonzero()[0]
        if len(hits) == 0:
            raise IndexError(f"subject_id {subject_id!r} not found in ECG records")
        return int(hits[0])

    def fetch(self, subject_id):
        """Return the ``__getitem__`` result for the first record of ``subject_id``."""
        return self[self.get_idx_by_subject(subject_id)]
27 changes: 27 additions & 0 deletions group_code/uni_model/echo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# noqa: F403
from torch.utils.data import Dataset

from mmai25_hackathon.load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list


class ECHO_uni(Dataset):
    """
    Unimodal dataset over the echo record list found under a root directory.

    ``records`` is the table returned by ``load_mimic_iv_echo_record_list``;
    this class reads its ``subject_id`` and ``echo_path`` columns.
    """

    # Metadata entries copied into the item returned by ``__getitem__``.
    _META_KEYS = ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate")

    def __init__(self, mod_root):
        self.records = load_mimic_iv_echo_record_list(mod_root)

    def __len__(self) -> int:
        """Return the number of rows in ``records``."""
        return len(self.records)

    def __getitem__(self, idx: int):
        """
        Load the echo study at row position ``idx`` and return selected metadata.

        Returns:
            ``{subject_id: meta_filtered}`` where ``meta_filtered`` holds those
            of ``NumberOfFrames``, ``Rows``, ``Columns``, ``FrameTime`` and
            ``CineRate`` that are present in the loaded metadata. The loaded
            frames are not part of the result.
        """
        row = self.records.iloc[idx]
        subject_id = row["subject_id"]
        _frames, meta = load_echo_dicom(row["echo_path"])
        meta_filtered = {k: meta[k] for k in self._META_KEYS if k in meta}
        return {subject_id: meta_filtered}

    def get_idx_by_subject(self, subject_id):
        """
        Return the row position of the first record of ``subject_id``.

        The position (not the index label) is returned because ``__getitem__``
        looks rows up with ``iloc``; the two differ whenever ``records`` does
        not carry a default 0..n-1 index.

        Raises:
            IndexError: If ``subject_id`` does not occur in ``records``.
        """
        hits = (self.records["subject_id"] == subject_id).to_numpy().nonzero()[0]
        if len(hits) == 0:
            raise IndexError(f"subject_id {subject_id!r} not found in echo records")
        return int(hits[0])

    def fetch(self, subject_id):
        """Return the ``__getitem__`` result for the first record of ``subject_id``."""
        return self[self.get_idx_by_subject(subject_id)]
30 changes: 30 additions & 0 deletions group_code/uni_model/ehr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from torch.utils.data import Dataset

from mmai25_hackathon.load_data.ehr import load_mimic_iv_ehr


class EHR_uni(Dataset):
    """
    Per-subject EHR lookup backed by ``load_mimic_iv_ehr``.

    Nothing is loaded at construction time; every ``fetch`` call queries the
    EHR root for a single subject. Items are addressed by ``subject_id``
    rather than by row position.
    """

    def __init__(self, mod_root):
        # Only the root path is kept; tables are read on demand in ``fetch``.
        self.root = mod_root

    def __len__(self) -> int:
        # NOTE(review): no code in this class assigns ``self.records``, so as
        # written this raises AttributeError — confirm what the length should be.
        return len(self.records)

    def __getitem__(self, subject_id):
        """Return ``fetch(subject_id)``; the key is a subject id, not a position."""
        return self.fetch(subject_id)

    def fetch(self, subject_id):
        """
        Query the ``icustays`` and ``admissions`` tables for one subject.

        Returns:
            ``{subject_id: result}`` where ``result`` is whatever
            ``load_mimic_iv_ehr`` returns for the query built below.
        """
        query = {
            "ehr_path": self.root,
            "module": "both",
            "tables": ["icustays", "admissions"],
            "index_cols": ["subject_id", "hadm_id"],
            "subset_cols": {
                "icustays": ["first_careunit"],
                "admissions": ["admittime"],
            },
            # The loader is given a plain int, whatever numeric type the
            # caller passed in.
            "filter_rows": {"subject_id": [int(subject_id)]},
            "merge": True,
            "join": "inner",
        }
        subject_rows = load_mimic_iv_ehr(**query)
        return {subject_id: subject_rows}
32 changes: 32 additions & 0 deletions group_test/test_102.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from group_code.mm_dataset import EEE_dataset
from group_code.uni_model.ecg import ECG_uni
from group_code.uni_model.echo import ECHO_uni
from group_code.uni_model.ehr import EHR_uni

# Dataset root directories handed to the dataset classes in ``test_by_id``.
# They are relative paths, so they resolve against the current working
# directory (presumably the repository root — confirm before running).
ecg_root = "mimic-iv/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/"
echo_root = "mimic-iv/mimic-iv-echo-0.1.physionet.org/"
ehr_root = "mimic-iv/mimic-iv-3.1/"


def test_by_id(id):
    """
    Check that the multimodal dataset agrees with the unimodal datasets.

    The subject ``id`` is fetched once through ``EEE_dataset`` and once through
    each of ``ECG_uni``, ``EHR_uni`` and ``ECHO_uni``; one field per modality
    is then compared between the two routes.
    """
    multimodal = EEE_dataset([ecg_root, ehr_root, echo_root], ["ecg", "ehr", "echo"])
    row = multimodal.get_idx_from_sub_id(id)
    combined = multimodal[row]

    # Build every unimodal dataset first, then query each for the same subject.
    ecg_ds, ehr_ds, echo_ds = ECG_uni(ecg_root), EHR_uni(ehr_root), ECHO_uni(echo_root)
    ecg_only = ecg_ds.fetch(id)
    ehr_only = ehr_ds.fetch(id)
    echo_only = echo_ds.fetch(id)

    # ECG: header comments; echo: image row count; EHR: size of the result.
    assert combined["ecg"][id][1]["comments"] == ecg_only[id][1]["comments"]
    assert combined["echo"][id]["Rows"] == echo_only[id]["Rows"]
    assert len(combined["ehr"][id]) == len(ehr_only[id])


if __name__ == "__main__":
    # Run the consistency check for each subject, in this order.
    for subject in (102, 101):
        test_by_id(subject)
Loading