Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
707a6c6
Add API Budget
tolik0 Feb 5, 2025
b6bcdd7
Refactor to move api_budget to root level
tolik0 Feb 6, 2025
040ff9e
Format
tolik0 Feb 6, 2025
824d2c6
Merge branch 'main' into tolik0/add-api-budget
tolik0 Feb 7, 2025
15f830c
Update for backward compatibility
tolik0 Feb 7, 2025
1285668
Add unit tests
tolik0 Feb 9, 2025
7be9842
Add FixedWindowCallRatePolicy unit test
tolik0 Feb 9, 2025
8d3bfce
Change the partitions limit to 1000
tolik0 Feb 10, 2025
509ea05
Refactored switching logic
tolik0 Feb 10, 2025
8d44150
Increase the limit for number of partitions in memory
tolik0 Feb 10, 2025
b3f9897
Merge branch 'tolik0/add-api-budget-limit-1000' into tolik0/refactor-…
tolik0 Feb 11, 2025
342375c
Refactor ConcurrentPerPartitionCursor to not use ConcurrentCursor wit…
tolik0 Feb 12, 2025
05f4db7
Delete code from another branch
tolik0 Feb 12, 2025
c0bc645
Fix cursor value from record
tolik0 Feb 12, 2025
52b95e3
Add throttling for state emitting in ConcurrentPerPartitionCursor
tolik0 Feb 13, 2025
1166a7a
Fix unit tests
tolik0 Feb 17, 2025
4a7d9ec
Move switching to global logic
tolik0 Feb 17, 2025
19ad269
Revert test limits
tolik0 Feb 17, 2025
667700f
Merge branch 'main' into tolik0/refactor-concurrent-global-cursor
tolik0 Feb 17, 2025
6498528
Fix format
tolik0 Feb 17, 2025
d3e7fe2
Add parent state updates
tolik0 Feb 17, 2025
7b4964e
Move acquiring the semaphore
tolik0 Feb 17, 2025
8617cc8
Merge branch 'tolik0/refactor-concurrent-global-cursor' into tolik0/c…
tolik0 Feb 17, 2025
a8db6b6
Merge branch 'main' into tolik0/concurrent-perpartition-add-parent-st…
tolik0 Feb 18, 2025
203c131
Refactor to store only unique states
tolik0 Feb 18, 2025
671fab4
Add intermediate states validation to unit tests
tolik0 Feb 18, 2025
a1d98fb
Fix format
tolik0 Feb 18, 2025
eff25ec
Add unit tests
tolik0 Feb 19, 2025
c51f840
Update unit tests
tolik0 Feb 21, 2025
4a18954
Add deleting finished semaphores
tolik0 Feb 21, 2025
a7ece97
Delete testing prints
tolik0 Feb 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import logging
import threading
import time
from collections import OrderedDict
from copy import deepcopy
from datetime import timedelta
Expand Down Expand Up @@ -58,7 +59,8 @@ class ConcurrentPerPartitionCursor(Cursor):
ConcurrentPerPartitionCursor expects the state of the ConcurrentCursor to follow the format {cursor_field: cursor_value}.
"""

DEFAULT_MAX_PARTITIONS_NUMBER = 10000
DEFAULT_MAX_PARTITIONS_NUMBER = 25_000
SWITCH_TO_GLOBAL_LIMIT = 10_000
_NO_STATE: Mapping[str, Any] = {}
_NO_CURSOR_STATE: Mapping[str, Any] = {}
_GLOBAL_STATE_KEY = "state"
Expand Down Expand Up @@ -93,15 +95,21 @@ def __init__(
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
self._cursor_per_partition: OrderedDict[str, ConcurrentCursor] = OrderedDict()
self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()

# Parent-state tracking: store each partition’s parent state in creation order
self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()

self._finished_partitions: set[str] = set()
self._lock = threading.Lock()
self._timer = Timer()
self._new_global_cursor: Optional[StreamState] = None
self._lookback_window: int = 0
self._parent_state: Optional[StreamState] = None
self._over_limit: int = 0
self._number_of_partitions: int = 0
self._use_global_cursor: bool = False
self._partition_serializer = PerPartitionKeySerializer()
# Track the last time a state message was emitted
self._last_emission_time: float = 0.0

self._set_initial_state(stream_state)

Expand Down Expand Up @@ -141,22 +149,40 @@ def close_partition(self, partition: Partition) -> None:
raise ValueError("stream_slice cannot be None")

partition_key = self._to_partition_key(stream_slice.partition)
self._cursor_per_partition[partition_key].close_partition(partition=partition)
with self._lock:
self._semaphore_per_partition[partition_key].acquire()
cursor = self._cursor_per_partition[partition_key]
if (
partition_key in self._finished_partitions
and self._semaphore_per_partition[partition_key]._value == 0
):
if not self._use_global_cursor:
self._cursor_per_partition[partition_key].close_partition(partition=partition)
cursor = self._cursor_per_partition[partition_key]
if (
self._new_global_cursor is None
or self._new_global_cursor[self.cursor_field.cursor_field_key]
< cursor.state[self.cursor_field.cursor_field_key]
partition_key in self._finished_partitions
and self._semaphore_per_partition[partition_key]._value == 0
):
self._new_global_cursor = copy.deepcopy(cursor.state)
if not self._use_global_cursor:
self._emit_state_message()
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])

self._check_and_update_parent_state()

self._emit_state_message()

self._semaphore_per_partition[partition_key].acquire()

def _check_and_update_parent_state(self) -> None:
    """
    Advance ``_parent_state`` past every leading partition that is fully closed.

    Partitions are examined in creation order (left to right in
    ``_partition_parent_state_map``). A partition is treated as fully closed when it
    is listed in ``_finished_partitions`` and the counter of its semaphore is zero.
    Each such leading partition is popped from the map, and the parent state stored
    for the last one popped becomes ``_parent_state``. The scan stops at the first
    partition that is not fully closed; if nothing was popped, ``_parent_state`` is
    left untouched.
    """
    last_closed_state = None
    parent_states = self._partition_parent_state_map

    # Only the leftmost (oldest) entry is ever inspected, so peek at it directly
    # instead of copying the full key list on every call.
    while parent_states:
        oldest_key = next(iter(parent_states))
        # NOTE: `_value` is a private attribute of threading.Semaphore; it is read
        # here to check the counter without blocking.
        if (
            oldest_key not in self._finished_partitions
            or self._semaphore_per_partition[oldest_key]._value != 0
        ):
            break
        # The oldest partition is fully closed: drop it and remember its parent state.
        _, last_closed_state = parent_states.popitem(last=False)

    # If we popped at least one partition, adopt that partition's parent state.
    if last_closed_state is not None:
        self._parent_state = last_closed_state

def ensure_at_least_one_state_emitted(self) -> None:
"""
Expand All @@ -169,9 +195,23 @@ def ensure_at_least_one_state_emitted(self) -> None:
self._global_cursor = self._new_global_cursor
self._lookback_window = self._timer.finish()
self._parent_state = self._partition_router.get_stream_state()
self._emit_state_message()
self._emit_state_message(throttle=False)

def _emit_state_message(self) -> None:
def _throttle_state_message(self) -> Optional[float]:
    """
    Decide whether a throttled state message may be emitted right now.

    Returns the current ``time.time()`` reading when more than 60 seconds separate it
    from ``_last_emission_time``, and ``None`` when the emission should be skipped.
    This method does not modify ``_last_emission_time``.
    """
    now = time.time()
    seconds_since_last_emission = now - self._last_emission_time
    # Exactly 60 seconds still counts as "too soon".
    return None if seconds_since_last_emission <= 60 else now

def _emit_state_message(self, throttle: bool = True) -> None:
if throttle:
current_time = self._throttle_state_message()
if current_time is None:
return
self._last_emission_time = current_time
self._connector_state_manager.update_state_for_stream(
self._stream_name,
self._stream_namespace,
Expand All @@ -188,32 +228,39 @@ def stream_slices(self) -> Iterable[StreamSlice]:

slices = self._partition_router.stream_slices()
self._timer.start()
for partition in slices:
yield from self._generate_slices_from_partition(partition)
for partition, last, parent_state in iterate_with_last_flag_and_state(
slices, self._partition_router.get_stream_state
):
yield from self._generate_slices_from_partition(partition, parent_state)

def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
def _generate_slices_from_partition(self, partition: StreamSlice, parent_state: Mapping[str, Any]) -> Iterable[StreamSlice]:
# Ensure the maximum number of partitions is not exceeded
self._ensure_partition_limit()

partition_key = self._to_partition_key(partition.partition)

cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
if not cursor:
cursor = self._create_cursor(
self._global_cursor,
self._lookback_window if self._global_cursor else 0,
)
with self._lock:
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
self._semaphore_per_partition[self._to_partition_key(partition.partition)] = (
self._cursor_per_partition[partition_key] = cursor
self._semaphore_per_partition[partition_key] = (
threading.Semaphore(0)
)

with self._lock:
self._partition_parent_state_map[partition_key] = deepcopy(parent_state)

for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
cursor.stream_slices(),
lambda: None,
):
self._semaphore_per_partition[self._to_partition_key(partition.partition)].release()
self._semaphore_per_partition[partition_key].release()
if is_last_slice:
self._finished_partitions.add(self._to_partition_key(partition.partition))
self._finished_partitions.add(partition_key)
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
)
Expand All @@ -232,9 +279,16 @@ def _ensure_partition_limit(self) -> None:
- Logs a warning each time a partition is removed, indicating whether it was finished
or removed due to being the oldest.
"""
if not self._use_global_cursor and self.limit_reached():
logger.info(
f"Exceeded the 'SWITCH_TO_GLOBAL_LIMIT' of {self.SWITCH_TO_GLOBAL_LIMIT}. "
f"Switching to global cursor for {self._stream_name}."
)
self._use_global_cursor = True

with self._lock:
self._number_of_partitions += 1
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
self._over_limit += 1
# Try removing finished partitions first
for partition_key in list(self._cursor_per_partition.keys()):
if (
Expand All @@ -245,7 +299,7 @@ def _ensure_partition_limit(self) -> None:
partition_key
) # Remove the oldest partition
logger.warning(
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._over_limit}."
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
)
break
else:
Expand All @@ -254,7 +308,7 @@ def _ensure_partition_limit(self) -> None:
1
] # Remove the oldest partition
logger.warning(
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}."
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
)

def _set_initial_state(self, stream_state: StreamState) -> None:
Expand Down Expand Up @@ -354,16 +408,26 @@ def _set_global_state(self, stream_state: Mapping[str, Any]) -> None:
self._new_global_cursor = deepcopy(fixed_global_state)

def observe(self, record: Record) -> None:
if not self._use_global_cursor and self.limit_reached():
self._use_global_cursor = True

if not record.associated_slice:
raise ValueError(
"Invalid state as stream slices that are emitted should refer to an existing cursor"
)
self._cursor_per_partition[
self._to_partition_key(record.associated_slice.partition)
].observe(record)

record_cursor = self._connector_state_converter.output_format(
self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))
)
self._update_global_cursor(record_cursor)
if not self._use_global_cursor:
self._cursor_per_partition[
self._to_partition_key(record.associated_slice.partition)
].observe(record)

def _update_global_cursor(self, value: Any) -> None:
    """
    Raise the pending global cursor to ``value`` if ``value`` is newer.

    ``_new_global_cursor`` is replaced by a single-entry mapping
    ``{cursor_field_key: deepcopy(value)}`` when no pending global cursor exists yet
    or when the stored value compares strictly less than ``value``. Otherwise it is
    left as is.
    """
    pending = self._new_global_cursor
    if pending is not None and not (pending[self.cursor_field.cursor_field_key] < value):
        # The stored cursor is already at or beyond this value.
        return

    field_key = self.cursor_field.cursor_field_key
    # Store a copy so later mutation of the caller's value cannot change the cursor.
    self._new_global_cursor = {field_key: copy.deepcopy(value)}

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)
Expand Down Expand Up @@ -397,4 +461,4 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor:
return cursor

def limit_reached(self) -> bool:
return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER
return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import deepcopy
from datetime import datetime, timedelta
from typing import Any, List, Mapping, MutableMapping, Optional, Union
from unittest.mock import MagicMock, patch
from urllib.parse import unquote

import pytest
Expand All @@ -18,6 +19,7 @@
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
ConcurrentDeclarativeSource,
)
from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read

Expand Down Expand Up @@ -1181,14 +1183,18 @@ def test_incremental_parent_state(
initial_state,
expected_state,
):
run_incremental_parent_state_test(
manifest,
mock_requests,
expected_records,
num_intermediate_states,
initial_state,
[expected_state],
)
# Patch `_throttle_state_message` so it always returns a float (indicating "no throttle")
with patch.object(
ConcurrentPerPartitionCursor, "_throttle_state_message", return_value=9999999.0
):
run_incremental_parent_state_test(
manifest,
mock_requests,
expected_records,
num_intermediate_states,
initial_state,
[expected_state],
)


STATE_MIGRATION_EXPECTED_STATE = {
Expand Down Expand Up @@ -2967,3 +2973,47 @@ def test_incremental_substream_request_options_provider(
expected_records,
expected_state,
)


def test_state_throttling(mocker):
    """
    Verifies the throttling behavior of _emit_state_message:

    - no state is emitted while 60s or less have passed since the last emission
      (the exact 60s boundary is still throttled);
    - a state is emitted once more than 60s have passed;
    - a throttled emission restarts the 60s window;
    - throttle=False bypasses the check.
    """
    cursor = ConcurrentPerPartitionCursor(
        cursor_factory=MagicMock(),
        partition_router=MagicMock(),
        stream_name="test_stream",
        stream_namespace=None,
        stream_state={},
        message_repository=MagicMock(),
        connector_state_manager=MagicMock(),
        connector_state_converter=MagicMock(),
        cursor_field=MagicMock(),
    )

    mock_connector_manager = cursor._connector_state_manager
    mock_repo = cursor._message_repository

    # Set the last emission time to "0" so we can control offset from that
    cursor._last_emission_time = 0

    mock_time = mocker.patch("time.time")

    # First attempt: only 10 seconds passed => NO emission
    mock_time.return_value = 10
    cursor._emit_state_message()
    mock_connector_manager.update_state_for_stream.assert_not_called()
    mock_repo.emit_message.assert_not_called()

    # Second attempt: 30 seconds passed => still NO emission
    mock_time.return_value = 30
    cursor._emit_state_message()
    mock_connector_manager.update_state_for_stream.assert_not_called()
    mock_repo.emit_message.assert_not_called()

    # Boundary: exactly 60 seconds is still inside the throttle window => NO emission
    mock_time.return_value = 60
    cursor._emit_state_message()
    mock_connector_manager.update_state_for_stream.assert_not_called()
    mock_repo.emit_message.assert_not_called()

    # Advance time: 70 seconds => exceed 60s => MUST emit
    mock_time.return_value = 70
    cursor._emit_state_message()
    mock_connector_manager.update_state_for_stream.assert_called_once()
    mock_repo.emit_message.assert_called_once()

    # The emission at t=70 restarts the window: only 30 seconds later => NO new emission
    mock_time.return_value = 100
    cursor._emit_state_message()
    mock_connector_manager.update_state_for_stream.assert_called_once()
    mock_repo.emit_message.assert_called_once()

    # throttle=False skips the throttle check entirely => MUST emit
    cursor._emit_state_message(throttle=False)
    assert mock_connector_manager.update_state_for_stream.call_count == 2
    assert mock_repo.emit_message.call_count == 2
Loading