Skip to content

Commit a183c4d

Browse files
authored
Add functions to determine sample ranges of signals in multi-segment records (#403)
* Add contained_ranges method to calculate sample ranges that contain a signal, in multi-segment records * Add contained_combined_ranges function * Add tests for new ranges functions * Fix logic to account for empty segments and derive some fields in MultiRecord constructor
1 parent 14df878 commit a183c4d

File tree

4 files changed

+257
-16
lines changed

4 files changed

+257
-16
lines changed

tests/test_multi_record.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import wfdb
2+
3+
4+
class TestMultiRecordRanges:
5+
"""
6+
Test logic that deduces relevant segments/ranges for given signals.
7+
"""
8+
9+
def test_contained_ranges_simple_cases(self):
10+
record = wfdb.MultiRecord(
11+
segments=[
12+
wfdb.Record(sig_name=["I", "II"], sig_len=5),
13+
wfdb.Record(sig_name=["I", "III"], sig_len=10),
14+
],
15+
)
16+
17+
assert record.contained_ranges("I") == [(0, 15)]
18+
assert record.contained_ranges("II") == [(0, 5)]
19+
assert record.contained_ranges("III") == [(5, 15)]
20+
21+
def test_contained_ranges_variable_layout(self):
22+
record = wfdb.rdheader(
23+
"sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
24+
rd_segments=True,
25+
)
26+
27+
assert record.contained_ranges("II") == [
28+
(3261, 10136),
29+
(4610865, 10370865),
30+
(10528365, 14518365),
31+
]
32+
assert record.contained_ranges("V") == [
33+
(3261, 918261),
34+
(920865, 4438365),
35+
(4610865, 10370865),
36+
(10528365, 14518365),
37+
]
38+
assert record.contained_ranges("MCL1") == [
39+
(10136, 918261),
40+
(920865, 4438365),
41+
]
42+
assert record.contained_ranges("ABP") == [
43+
(14428365, 14450865),
44+
(14458365, 14495865),
45+
]
46+
47+
def test_contained_ranges_fixed_layout(self):
48+
record = wfdb.rdheader(
49+
"sample-data/multi-segment/041s/041s",
50+
rd_segments=True,
51+
)
52+
53+
for sig_name in record.sig_name:
54+
assert record.contained_ranges(sig_name) == [(0, 2000)]
55+
56+
def test_contained_combined_ranges_simple_cases(self):
57+
record = wfdb.MultiRecord(
58+
segments=[
59+
wfdb.Record(sig_name=["I", "II", "V"], sig_len=5),
60+
wfdb.Record(sig_name=["I", "III", "V"], sig_len=10),
61+
wfdb.Record(sig_name=["I", "II", "V"], sig_len=20),
62+
],
63+
)
64+
65+
assert record.contained_combined_ranges(["I", "II"]) == [
66+
(0, 5),
67+
(15, 35),
68+
]
69+
assert record.contained_combined_ranges(["II", "III"]) == []
70+
assert record.contained_combined_ranges(["I", "III"]) == [(5, 15)]
71+
assert record.contained_combined_ranges(["I", "II", "V"]) == [
72+
(0, 5),
73+
(15, 35),
74+
]
75+
76+
def test_contained_combined_ranges_variable_layout(self):
77+
record = wfdb.rdheader(
78+
"sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
79+
rd_segments=True,
80+
)
81+
82+
assert record.contained_combined_ranges(["II", "V"]) == [
83+
(3261, 10136),
84+
(4610865, 10370865),
85+
(10528365, 14518365),
86+
]
87+
assert record.contained_combined_ranges(["II", "MCL1"]) == []
88+
assert record.contained_combined_ranges(["II", "ABP"]) == [
89+
(14428365, 14450865),
90+
(14458365, 14495865),
91+
]
92+
assert record.contained_combined_ranges(["II", "V", "ABP"]) == [
93+
(14428365, 14450865),
94+
(14458365, 14495865),
95+
]
96+
assert (
97+
record.contained_combined_ranges(["II", "V", "MCL1", "ABP"]) == []
98+
)
99+
100+
def test_contained_combined_ranges_variable_layout(self):
101+
record = wfdb.rdheader(
102+
"sample-data/multi-segment/041s/041s",
103+
rd_segments=True,
104+
)
105+
106+
for sig_1 in record.sig_name:
107+
for sig_2 in record.sig_name:
108+
if sig_1 == sig_2:
109+
continue
110+
111+
assert record.contained_combined_ranges([sig_1, sig_2]) == [
112+
(0, 2000)
113+
]

wfdb/io/_header.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
import re
3-
from typing import List, Tuple
3+
from typing import Collection, List, Tuple
44

55
import numpy as np
66
import pandas as pd
@@ -858,7 +858,7 @@ def get_sig_name(self):
858858
"""
859859
if self.segments is None:
860860
raise Exception(
861-
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rsegment_fieldsments=True"
861+
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
862862
)
863863

864864
if self.layout == "fixed":
@@ -871,6 +871,114 @@ def get_sig_name(self):
871871

872872
return sig_name
873873

874+
def contained_ranges(self, sig_name: str) -> List[Tuple[int, int]]:
875+
"""
876+
Given a signal name, return the sample ranges that contain signal values,
877+
relative to the start of the full record. Does not account for NaNs/missing
878+
values.
879+
880+
This function is mainly useful for variable layout records, but can also be
881+
used for fixed-layout records. Only works if the headers from the individual
882+
segment records have already been read in.
883+
884+
Parameters
885+
----------
886+
sig_name : str
887+
The name of the signal to query.
888+
889+
Returns
890+
-------
891+
ranges : List[Tuple[int, int]]
892+
Tuple pairs which specify thee sample ranges in which the signal is contained.
893+
The second value of each tuple pair will be one beyond the signal index.
894+
eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
895+
selection using signal[0:1000].
896+
897+
"""
898+
if self.segments is None:
899+
raise Exception(
900+
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
901+
)
902+
ranges = []
903+
seg_start = 0
904+
905+
range_start = None
906+
907+
# TODO: Add shortcut for fixed-layout records
908+
909+
# Cannot process segments only because missing segments are None
910+
# and do not contain length information.
911+
for seg_num in range(self.n_seg):
912+
seg_len = self.seg_len[seg_num]
913+
segment = self.segments[seg_num]
914+
915+
if seg_len == 0:
916+
continue
917+
918+
# Open signal range
919+
if (
920+
range_start is None
921+
and segment is not None
922+
and sig_name in segment.sig_name
923+
):
924+
range_start = seg_start
925+
# Close signal range
926+
elif range_start is not None and (
927+
segment is None or sig_name not in segment.sig_name
928+
):
929+
ranges.append((range_start, seg_start))
930+
range_start = None
931+
932+
seg_start += seg_len
933+
934+
# Account for final segment
935+
if range_start is not None:
936+
ranges.append((range_start, seg_start))
937+
938+
return ranges
939+
940+
def contained_combined_ranges(
941+
self,
942+
sig_names: Collection[str],
943+
) -> List[Tuple[int, int]]:
944+
"""
945+
Given a collection of signal name, return the sample ranges that
946+
contain all of the specified signals, relative to the start of the
947+
full record. Does not account for NaNs/missing values.
948+
949+
This function is mainly useful for variable layout records, but can also be
950+
used for fixed-layout records. Only works if the headers from the individual
951+
segment records have already been read in.
952+
953+
Parameters
954+
----------
955+
sig_names : List[str]
956+
The names of the signals to query.
957+
958+
Returns
959+
-------
960+
ranges : List[Tuple[int, int]]
961+
Tuple pairs which specify thee sample ranges in which the signal is contained.
962+
The second value of each tuple pair will be one beyond the signal index.
963+
eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
964+
selection using signal[0:1000].
965+
966+
"""
967+
# TODO: Add shortcut for fixed-layout records
968+
969+
if len(sig_names) == 0:
970+
return []
971+
972+
combined_ranges = self.contained_ranges(sig_names[0])
973+
974+
if len(sig_names) > 1:
975+
for name in sig_names[1:]:
976+
combined_ranges = util.overlapping_ranges(
977+
combined_ranges, self.contained_ranges(name)
978+
)
979+
980+
return combined_ranges
981+
874982

875983
def wfdb_strptime(time_string: str) -> datetime.time:
876984
"""

wfdb/io/record.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,8 +1080,8 @@ class MultiRecord(BaseRecord, _header.MultiHeaderMixin):
10801080
`datetime.combine(base_date, base_time)`.
10811081
seg_name : str, optional
10821082
The name of the segment.
1083-
seg_len : int, optional
1084-
The length of the segment.
1083+
seg_len : List[int], optional
1084+
The length of each segment.
10851085
comments : list, optional
10861086
A list of string comments to be written to the header file.
10871087
sig_name : str, optional
@@ -1144,6 +1144,11 @@ def __init__(
11441144
self.seg_len = seg_len
11451145
self.sig_segments = sig_segments
11461146

1147+
if segments:
1148+
self.n_seg = len(segments)
1149+
if not seg_len:
1150+
self.seg_len = [segment.sig_len for segment in segments]
1151+
11471152
def wrsamp(self, write_dir=""):
11481153
"""
11491154
Write a multi-segment header, along with headers and dat files
@@ -1184,33 +1189,28 @@ def _check_segment_cohesion(self):
11841189
if self.n_seg != len(self.segments):
11851190
raise ValueError("Length of segments must match the 'n_seg' field")
11861191

1187-
for i in range(n_seg):
1188-
s = self.segments[i]
1192+
for seg_num, segment in enumerate(self.segments):
11891193

11901194
# If segment 0 is a layout specification record, check that its file names are all == '~''
1191-
if i == 0 and self.seg_len[0] == 0:
1192-
for file_name in s.file_name:
1195+
if seg_num == 0 and self.seg_len[0] == 0:
1196+
for file_name in segment.file_name:
11931197
if file_name != "~":
11941198
raise ValueError(
11951199
"Layout specification records must have all file_names named '~'"
11961200
)
11971201

11981202
# Sampling frequencies must all match the one in the master header
1199-
if s.fs != self.fs:
1203+
if segment.fs != self.fs:
12001204
raise ValueError(
12011205
"The 'fs' in each segment must match the overall record's 'fs'"
12021206
)
12031207

12041208
# Check the signal length of the segment against the corresponding seg_len field
1205-
if s.sig_len != self.seg_len[i]:
1209+
if segment.sig_len != self.seg_len[seg_num]:
12061210
raise ValueError(
1207-
"The signal length of segment "
1208-
+ str(i)
1209-
+ " does not match the corresponding segment length"
1211+
f"The signal length of segment {seg_num} does not match the corresponding segment length"
12101212
)
12111213

1212-
totalsig_len = totalsig_len + getattr(s, "sig_len")
1213-
12141214
# No need to check the sum of sig_lens from each segment object against sig_len
12151215
# Already effectively done it when checking sum(seg_len) against sig_len
12161216

wfdb/io/util.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import math
55
import os
66

7-
from typing import Sequence
7+
from typing import Sequence, Tuple
88

99

1010
def lines_to_file(file_name: str, write_dir: str, lines: Sequence[str]):
@@ -99,3 +99,23 @@ def upround(x, base):
9999
100100
"""
101101
return base * math.ceil(float(x) / base)
102+
103+
104+
def overlapping_ranges(
105+
ranges_1: Tuple[int, int], ranges_2: Tuple[int, int]
106+
) -> Tuple[int, int]:
107+
"""
108+
Given two collections of integer ranges, return a list of ranges
109+
in which both input inputs overlap.
110+
111+
From: https://stackoverflow.com/q/40367461
112+
113+
Slightly modified so that if the end of one range exactly equals
114+
the start of the other range, no overlap would be returned.
115+
"""
116+
return [
117+
(max(first[0], second[0]), min(first[1], second[1]))
118+
for first in ranges_1
119+
for second in ranges_2
120+
if max(first[0], second[0]) < min(first[1], second[1])
121+
]

0 commit comments

Comments
 (0)