Skip to content
68 changes: 50 additions & 18 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AlwaysAvailableAvailabilityStrategy,
)
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, FinalStateCursor
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream

Expand Down Expand Up @@ -193,31 +194,44 @@ def _group_streams(
declarative_stream.name
].get("incremental_sync")

if (
is_without_partition_router_nor_cursor = not bool(
datetime_based_cursor_component_definition
and datetime_based_cursor_component_definition.get("type", "")
== DatetimeBasedCursorModel.__name__
and self._stream_supports_concurrent_partition_processing(
declarative_stream=declarative_stream
) and not (
name_to_stream_mapping[declarative_stream.name]
.get("retriever", {})
.get("partition_router")
)
is_datetime_incremental_without_partition_routing = (
self._is_datetime_incremental_without_partition_routing(
datetime_based_cursor_component_definition, declarative_stream
)
and hasattr(declarative_stream.retriever, "stream_slicer")
and isinstance(declarative_stream.retriever.stream_slicer, DatetimeBasedCursor)
)
if (
is_without_partition_router_nor_cursor
or is_datetime_incremental_without_partition_routing
):
stream_state = state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)

cursor, connector_state_converter = (
self._constructor.create_concurrent_cursor_from_datetime_based_cursor(
state_manager=state_manager,
model_type=DatetimeBasedCursorModel,
component_definition=datetime_based_cursor_component_definition,
stream_name=declarative_stream.name,
stream_namespace=declarative_stream.namespace,
config=config or {},
stream_state=stream_state,
if is_datetime_incremental_without_partition_routing:
cursor: Cursor = (
self._constructor.create_concurrent_cursor_from_datetime_based_cursor(
state_manager=state_manager,
model_type=DatetimeBasedCursorModel,
component_definition=datetime_based_cursor_component_definition,
stream_name=declarative_stream.name,
stream_namespace=declarative_stream.namespace,
config=config or {},
stream_state=stream_state,
)
)
else:
cursor = FinalStateCursor(
declarative_stream.name,
declarative_stream.namespace,
self.message_repository,
)
)

partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
Expand All @@ -240,7 +254,9 @@ def _group_streams(
json_schema=declarative_stream.get_json_schema(),
availability_strategy=AlwaysAvailableAvailabilityStrategy(),
primary_key=get_primary_key_from_stream(declarative_stream.primary_key),
cursor_field=cursor.cursor_field.cursor_field_key,
cursor_field=cursor.cursor_field.cursor_field_key
if hasattr(cursor, "cursor_field")
else None,
logger=self.logger,
cursor=cursor,
)
Expand All @@ -252,6 +268,22 @@ def _group_streams(

return concurrent_streams, synchronous_streams

def _is_datetime_incremental_without_partition_routing(
self,
datetime_based_cursor_component_definition: Mapping[str, Any],
declarative_stream: DeclarativeStream,
) -> bool:
return (
bool(datetime_based_cursor_component_definition)
and datetime_based_cursor_component_definition.get("type", "")
== DatetimeBasedCursorModel.__name__
and self._stream_supports_concurrent_partition_processing(
declarative_stream=declarative_stream
)
and hasattr(declarative_stream.retriever, "stream_slicer")
and isinstance(declarative_stream.retriever.stream_slicer, DatetimeBasedCursor)
)

def _stream_supports_concurrent_partition_processing(
self, declarative_stream: DeclarativeStream
) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Mapping,
MutableMapping,
Optional,
Tuple,
Type,
Union,
get_args,
Expand Down Expand Up @@ -753,7 +752,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
config: Config,
stream_state: MutableMapping[str, Any],
**kwargs: Any,
) -> Tuple[ConcurrentCursor, DateTimeStreamStateConverter]:
) -> ConcurrentCursor:
component_type = component_definition.get("type")
if component_definition.get("type") != model_type.__name__:
raise ValueError(
Expand Down Expand Up @@ -884,23 +883,20 @@ def create_concurrent_cursor_from_datetime_based_cursor(
if evaluated_step:
step_length = parse_duration(evaluated_step)

return (
ConcurrentCursor(
stream_name=stream_name,
stream_namespace=stream_namespace,
stream_state=stream_state,
message_repository=self._message_repository, # type: ignore # message_repository is always instantiated with a value by factory
connector_state_manager=state_manager,
connector_state_converter=connector_state_converter,
cursor_field=cursor_field,
slice_boundary_fields=slice_boundary_fields,
start=start_date, # type: ignore # Having issues w/ inspection for GapType and CursorValueType as shown in existing tests. Confirmed functionality is working in practice
end_provider=end_date_provider, # type: ignore # Having issues w/ inspection for GapType and CursorValueType as shown in existing tests. Confirmed functionality is working in practice
lookback_window=lookback_window,
slice_range=step_length,
cursor_granularity=cursor_granularity,
),
connector_state_converter,
return ConcurrentCursor(
stream_name=stream_name,
stream_namespace=stream_namespace,
stream_state=stream_state,
message_repository=self._message_repository, # type: ignore # message_repository is always instantiated with a value by factory
connector_state_manager=state_manager,
connector_state_converter=connector_state_converter,
cursor_field=cursor_field,
slice_boundary_fields=slice_boundary_fields,
start=start_date, # type: ignore # Having issues w/ inspection for GapType and CursorValueType as shown in existing tests. Confirmed functionality is working in practice
end_provider=end_date_provider, # type: ignore # Having issues w/ inspection for GapType and CursorValueType as shown in existing tests. Confirmed functionality is working in practice
lookback_window=lookback_window,
slice_range=step_length,
cursor_granularity=cursor_granularity,
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3055,7 +3055,7 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(
"lookback_window": "P3D",
}

concurrent_cursor, stream_state_converter = (
concurrent_cursor = (
connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor(
state_manager=connector_state_manager,
model_type=DatetimeBasedCursorModel,
Expand Down Expand Up @@ -3087,6 +3087,7 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(
assert concurrent_cursor._end_provider() == expected_end
assert concurrent_cursor._concurrent_state == expected_concurrent_state

stream_state_converter = concurrent_cursor._connector_state_converter
assert isinstance(stream_state_converter, CustomFormatConcurrentStreamStateConverter)
assert stream_state_converter._datetime_format == expected_datetime_format
assert stream_state_converter._is_sequential_state
Expand Down Expand Up @@ -3187,7 +3188,7 @@ def test_create_concurrent_cursor_from_datetime_based_cursor(
stream_state={},
)
else:
concurrent_cursor, stream_state_converter = (
concurrent_cursor = (
connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor(
state_manager=connector_state_manager,
model_type=DatetimeBasedCursorModel,
Expand Down Expand Up @@ -3244,7 +3245,7 @@ def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined():
"lookback_window": "P3D",
}

concurrent_cursor, stream_state_converter = (
concurrent_cursor = (
connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor(
state_manager=connector_state_manager,
model_type=DatetimeBasedCursorModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,20 +479,19 @@ def test_group_streams():
synchronous_streams = source._synchronous_streams

# 2 incremental streams
assert len(concurrent_streams) == 2
concurrent_stream_0, concurrent_stream_1 = concurrent_streams
assert len(concurrent_streams) == 3
concurrent_stream_0, concurrent_stream_1, concurrent_stream_2 = concurrent_streams
assert isinstance(concurrent_stream_0, DefaultStream)
assert concurrent_stream_0.name == "party_members"
assert isinstance(concurrent_stream_1, DefaultStream)
assert concurrent_stream_1.name == "locations"
assert concurrent_stream_1.name == "palaces"
assert isinstance(concurrent_stream_2, DefaultStream)
assert concurrent_stream_2.name == "locations"

# 1 full refresh stream, 1 substream
assert len(synchronous_streams) == 2
synchronous_stream_0, synchronous_stream_1 = synchronous_streams
assert isinstance(synchronous_stream_0, DeclarativeStream)
assert synchronous_stream_0.name == "palaces"
assert isinstance(synchronous_stream_1, DeclarativeStream)
assert synchronous_stream_1.name == "party_members_skills"
assert len(synchronous_streams) == 1
assert isinstance(synchronous_streams[0], DeclarativeStream)
assert synchronous_streams[0].name == "party_members_skills"


@freezegun.freeze_time(time_to_freeze=datetime(2024, 9, 1, 0, 0, 0, 0, tzinfo=timezone.utc))
Expand Down Expand Up @@ -539,7 +538,7 @@ def test_create_concurrent_cursor():
assert party_members_cursor._lookback_window == timedelta(days=5)
assert party_members_cursor._cursor_granularity == timedelta(days=1)

locations_stream = source._concurrent_streams[1]
locations_stream = source._concurrent_streams[2]
assert isinstance(locations_stream, DefaultStream)
locations_cursor = locations_stream.cursor

Expand Down Expand Up @@ -754,7 +753,7 @@ def test_read_with_concurrent_and_synchronous_streams():
assert len(palaces_states) == 1
assert (
palaces_states[0].stream.stream_state.__dict__
== AirbyteStateBlob(__ab_full_refresh_sync_complete=True).__dict__
== AirbyteStateBlob(__ab_no_cursor_state_message=True).__dict__
)

# Expects 3 records, 3 slices, 3 records in slice
Expand Down Expand Up @@ -1275,8 +1274,8 @@ def test_streams_with_stream_state_interpolation_should_be_synchronous():
state=None,
)

assert len(source._concurrent_streams) == 0
assert len(source._synchronous_streams) == 4
assert len(source._concurrent_streams) == 1
assert len(source._synchronous_streams) == 3


def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurrent():
Expand Down Expand Up @@ -1571,5 +1570,6 @@ def get_states_for_stream(


def disable_emitting_sequential_state_messages(source: ConcurrentDeclarativeSource) -> None:
for concurrent_streams in source._concurrent_streams: # type: ignore # This is the easiest way to disable behavior from the test
concurrent_streams.cursor._connector_state_converter._is_sequential_state = False # type: ignore # see above
for concurrent_stream in source._concurrent_streams: # type: ignore # This is the easiest way to disable behavior from the test
if isinstance(concurrent_stream.cursor, ConcurrentCursor):
concurrent_stream.cursor._connector_state_converter._is_sequential_state = False # type: ignore # see above
Loading