Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 05092f6

Browse files
committed
Refactor code in TableSegment
1 parent a797a87 commit 05092f6

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

data_diff/table_segment.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414

1515
RECOMMENDED_CHECKSUM_DURATION = 20
1616

17+
def _make_range(col, min_value=None, max_value=None):
18+
if min_value is not None:
19+
yield min_value <= col
20+
if max_value is not None:
21+
yield col < max_value
22+
1723

1824
@dataclass
1925
class TableSegment:
@@ -78,28 +84,25 @@ def with_schema(self) -> "TableSegment":
7884

7985
return self._with_raw_schema(self.database.query_table_schema(self.table_path))
8086

81-
def _make_key_range(self):
82-
if self.min_key is not None:
83-
assert len(self.key_columns) == 1
84-
(k,) = self.key_columns
85-
yield self.min_key <= this[k]
86-
if self.max_key is not None:
87+
def _make_where(self):
88+
# min_key <= x < max_key
89+
if self.min_key is not None and self.max_key is not None:
8790
assert len(self.key_columns) == 1
8891
(k,) = self.key_columns
89-
yield this[k] < self.max_key
92+
yield from _make_range(this[k], self.min_key, self.max_key)
93+
94+
# min_update <= x < max_update
95+
yield from _make_range(this[self.update_column], self.min_update, self.max_update)
9096

91-
def _make_update_range(self):
92-
if self.min_update is not None:
93-
yield self.min_update <= this[self.update_column]
94-
if self.max_update is not None:
95-
yield this[self.update_column] < self.max_update
97+
# user-defined where
98+
yield self.where or SKIP
9699

97100
@property
98101
def source_table(self):
99102
return table(*self.table_path, schema=self._schema)
100103

101104
def make_select(self):
102-
return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP)
105+
return self.source_table.where(*self._make_where())
103106

104107
def get_values(self) -> list:
105108
"Download all the relevant values of the segment from the database"

0 commit comments

Comments
 (0)