Skip to content

Commit 203c131

Browse files
committed
Refactor to store only unique states
1 parent a8db6b6 commit 203c131

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -166,20 +166,36 @@ def close_partition(self, partition: Partition) -> None:
166166

167167
def _check_and_update_parent_state(self) -> None:
168168
"""
169-
If all slices for the earliest partitions are closed, pop them from the left
170-
of _partition_parent_state_map and update _parent_state to the most recent popped.
169+
Pop the leftmost partition state from _partition_parent_state_map only if
170+
*all partitions* up to (and including) that partition key in _semaphore_per_partition
171+
are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
171172
"""
172173
last_closed_state = None
173-
# We iterate in creation order (left to right) in the OrderedDict
174-
for p_key in list(self._partition_parent_state_map.keys()):
175-
# If this partition is not fully closed, stop
176-
if p_key not in self._finished_partitions or self._semaphore_per_partition[p_key]._value != 0:
174+
175+
while self._partition_parent_state_map:
176+
# Look at the earliest partition key in creation order
177+
earliest_key = next(iter(self._partition_parent_state_map))
178+
179+
# Verify ALL partitions from the left up to earliest_key are finished
180+
all_left_finished = True
181+
for p_key, sem in self._semaphore_per_partition.items():
182+
# If any earlier partition is still not finished, we must stop
183+
if p_key not in self._finished_partitions or sem._value != 0:
184+
all_left_finished = False
185+
break
186+
# Once we've reached earliest_key in the semaphore order, we can stop checking
187+
if p_key == earliest_key:
188+
break
189+
190+
# If the partitions up to earliest_key are not all finished, break the while-loop
191+
if not all_left_finished:
177192
break
178-
# Otherwise, we pop from the left
193+
194+
# Otherwise, pop the leftmost entry from parent-state map
179195
_, closed_parent_state = self._partition_parent_state_map.popitem(last=False)
180196
last_closed_state = closed_parent_state
181197

182-
# If we popped at least one partition, update the parent_state to that partition's parent state
198+
# Update _parent_state if we actually popped at least one partition
183199
if last_closed_state is not None:
184200
self._parent_state = last_closed_state
185201

@@ -228,11 +244,13 @@ def stream_slices(self) -> Iterable[StreamSlice]:
228244
slices = self._partition_router.stream_slices()
229245
self._timer.start()
230246
for partition, last, parent_state in iterate_with_last_flag_and_state(
231-
slices, self._partition_router.get_stream_state
247+
slices, self._partition_router.get_stream_state
232248
):
233249
yield from self._generate_slices_from_partition(partition, parent_state)
234250

235-
def _generate_slices_from_partition(self, partition: StreamSlice, parent_state: Mapping[str, Any]) -> Iterable[StreamSlice]:
251+
def _generate_slices_from_partition(
252+
self, partition: StreamSlice, parent_state: Mapping[str, Any]
253+
) -> Iterable[StreamSlice]:
236254
# Ensure the maximum number of partitions is not exceeded
237255
self._ensure_partition_limit()
238256

@@ -247,12 +265,17 @@ def _generate_slices_from_partition(self, partition: StreamSlice, parent_state:
247265
with self._lock:
248266
self._number_of_partitions += 1
249267
self._cursor_per_partition[partition_key] = cursor
250-
self._semaphore_per_partition[partition_key] = (
251-
threading.Semaphore(0)
252-
)
268+
self._semaphore_per_partition[partition_key] = threading.Semaphore(0)
253269

254270
with self._lock:
255-
self._partition_parent_state_map[partition_key] = deepcopy(parent_state)
271+
if (
272+
len(self._partition_parent_state_map) == 0
273+
or self._partition_parent_state_map[
274+
next(reversed(self._partition_parent_state_map))
275+
]
276+
!= parent_state
277+
):
278+
self._partition_parent_state_map[partition_key] = deepcopy(parent_state)
256279

257280
for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
258281
cursor.stream_slices(),
@@ -287,7 +310,6 @@ def _ensure_partition_limit(self) -> None:
287310
self._use_global_cursor = True
288311

289312
with self._lock:
290-
self._number_of_partitions += 1
291313
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
292314
# Try removing finished partitions first
293315
for partition_key in list(self._cursor_per_partition.keys()):
@@ -372,9 +394,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
372394
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
373395
self._create_cursor(state["cursor"])
374396
)
375-
self._semaphore_per_partition[self._to_partition_key(state["partition"])] = (
376-
threading.Semaphore(0)
377-
)
378397

379398
# set default state for missing partitions if it is per partition with fallback to global
380399
if self._GLOBAL_STATE_KEY in stream_state:

unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,6 +2027,8 @@ def test_incremental_parent_state_no_records(
20272027
"cursor": {"updated_at": PARENT_COMMENT_CURSOR_PARTITION_1},
20282028
}
20292029
],
2030+
"state": {},
2031+
"use_global_cursor": False,
20302032
"parent_state": {"posts": {"updated_at": PARENT_POSTS_CURSOR}},
20312033
}
20322034
},

0 commit comments

Comments
 (0)