diff --git a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py
index 8718004bf..db5f5fae7 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py
@@ -149,6 +149,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
         for stream_slice_tuple in product:
             partition = dict(ChainMap(*[s.partition for s in stream_slice_tuple]))  # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
             cursor_slices = [s.cursor_slice for s in stream_slice_tuple if s.cursor_slice]
+            extra_fields = dict(ChainMap(*[s.extra_fields for s in stream_slice_tuple]))  # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
             if len(cursor_slices) > 1:
                 raise ValueError(
                     f"There should only be a single cursor slice. Found {cursor_slices}"
@@ -157,7 +158,9 @@
                 cursor_slice = cursor_slices[0]
             else:
                 cursor_slice = {}
-            yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
+            yield StreamSlice(
+                partition=partition, cursor_slice=cursor_slice, extra_fields=extra_fields
+            )
 
     def set_initial_state(self, stream_state: StreamState) -> None:
         """
diff --git a/unit_tests/sources/declarative/partition_routers/helpers.py b/unit_tests/sources/declarative/partition_routers/helpers.py
new file mode 100644
index 000000000..b64b40d5e
--- /dev/null
+++ b/unit_tests/sources/declarative/partition_routers/helpers.py
@@ -0,0 +1,93 @@
+#
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+#
+
+from typing import Any, Iterable, List, Mapping, Optional, Union
+
+from airbyte_cdk.models import SyncMode
+from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
+from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
+from airbyte_cdk.sources.streams.checkpoint import Cursor
+from airbyte_cdk.sources.types import Record, StreamSlice
+
+
+class MockStream(DeclarativeStream):
+    def __init__(self, slices, records, name, cursor_field="", cursor=None):
+        self.config = {}
+        self._slices = slices
+        self._records = records
+        self._stream_cursor_field = (
+            InterpolatedString.create(cursor_field, parameters={})
+            if isinstance(cursor_field, str)
+            else cursor_field
+        )
+        self._name = name
+        self._state = {"states": []}
+        self._cursor = cursor
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+        return "id"
+
+    @property
+    def state(self) -> Mapping[str, Any]:
+        return self._state
+
+    @state.setter
+    def state(self, value: Mapping[str, Any]) -> None:
+        self._state = value
+
+    @property
+    def is_resumable(self) -> bool:
+        return bool(self._cursor)
+
+    def get_cursor(self) -> Optional[Cursor]:
+        return self._cursor
+
+    def stream_slices(
+        self,
+        *,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[Optional[StreamSlice]]:
+        for s in self._slices:
+            if isinstance(s, StreamSlice):
+                yield s
+            else:
+                yield StreamSlice(partition=s, cursor_slice={})
+
+    def read_records(
+        self,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[Mapping[str, Any]]:
+        # The parent stream's records should always be read as full refresh
+        assert sync_mode == SyncMode.full_refresh
+
+        if not stream_slice:
+            result = self._records
+        else:
+            result = [
+                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
+                for r in self._records
+                if r["slice"] == stream_slice["slice"]
+            ]
+
+        yield from result
+
+        # Update the state only after reading the full slice
+        cursor_field = self._stream_cursor_field.eval(config=self.config)
+        if stream_slice and cursor_field and result:
+            self._state["states"].append(
+                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
+            )
+
+    def get_json_schema(self) -> Mapping[str, Any]:
+        return {}
diff --git a/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py
index 1d5981f7b..65122adc9 100644
--- a/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py
+++ b/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py
@@ -11,11 +11,16 @@
     CartesianProductStreamSlicer,
     ListPartitionRouter,
 )
+from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
+    ParentStreamConfig,
+    SubstreamPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.requesters.request_option import (
     RequestOption,
     RequestOptionType,
 )
 from airbyte_cdk.sources.types import StreamSlice
+from unit_tests.sources.declarative.partition_routers.helpers import MockStream
 
 
 @pytest.mark.parametrize(
@@ -171,6 +176,68 @@ def test_substream_slicer(test_name, stream_slicers, expected_slices):
     assert slices == expected_slices
 
 
+@pytest.mark.parametrize(
+    "test_name, stream_slicers, expected_slices",
+    [
+        (
+            "test_single_stream_slicer",
+            [
+                SubstreamPartitionRouter(
+                    parent_stream_configs=[
+                        ParentStreamConfig(
+                            stream=MockStream(
+                                [{}],
+                                [
+                                    {"a": {"b": 0}, "extra_field_key": "extra_field_value_0"},
+                                    {"a": {"b": 1}, "extra_field_key": "extra_field_value_1"},
+                                    {"a": {"c": 2}, "extra_field_key": "extra_field_value_2"},
+                                    {"a": {"b": 3}, "extra_field_key": "extra_field_value_3"},
+                                ],
+                                "first_stream",
+                            ),
+                            parent_key="a/b",
+                            partition_field="first_stream_id",
+                            parameters={},
+                            config={},
+                            extra_fields=[["extra_field_key"]],
+                        )
+                    ],
+                    parameters={},
+                    config={},
+                ),
+            ],
+            [
+                StreamSlice(
+                    partition={"first_stream_id": 0, "parent_slice": {}},
+                    cursor_slice={},
+                    extra_fields={"extra_field_key": "extra_field_value_0"},
+                ),
+                StreamSlice(
+                    partition={"first_stream_id": 1, "parent_slice": {}},
+                    cursor_slice={},
+                    extra_fields={"extra_field_key": "extra_field_value_1"},
+                ),
+                StreamSlice(
+                    partition={"first_stream_id": 3, "parent_slice": {}},
+                    cursor_slice={},
+                    extra_fields={"extra_field_key": "extra_field_value_3"},
+                ),
+            ],
+        )
+    ],
+)
+def test_substream_slicer_with_extra_fields(test_name, stream_slicers, expected_slices):
+    slicer = CartesianProductStreamSlicer(stream_slicers=stream_slicers, parameters={})
+    slices = [s for s in slicer.stream_slices()]
+    partitions = [s.partition for s in slices]
+    expected_partitions = [s.partition for s in expected_slices]
+    assert partitions == expected_partitions
+
+    extra_fields = [s.extra_fields for s in slices]
+    expected_extra_fields = [s.extra_fields for s in expected_slices]
+    assert extra_fields == expected_extra_fields
+
+
 def test_stream_slices_raises_exception_if_multiple_cursor_slice_components():
     stream_slicers = [
         DatetimeBasedCursor(
diff --git a/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py
index 7b09e50dd..122c8dfae 100644
--- a/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py
+++ b/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py
@@ -34,6 +34,7 @@
 from airbyte_cdk.sources.streams.checkpoint import Cursor
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils import AirbyteTracedException
+from unit_tests.sources.declarative.partition_routers.helpers import MockStream
 
 parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}]
 more_records = [
@@ -63,88 +64,6 @@
 )
 
 
-class MockStream(DeclarativeStream):
-    def __init__(self, slices, records, name, cursor_field="", cursor=None):
-        self.config = {}
-        self._slices = slices
-        self._records = records
-        self._stream_cursor_field = (
-            InterpolatedString.create(cursor_field, parameters={})
-            if isinstance(cursor_field, str)
-            else cursor_field
-        )
-        self._name = name
-        self._state = {"states": []}
-        self._cursor = cursor
-
-    @property
-    def name(self) -> str:
-        return self._name
-
-    @property
-    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
-        return "id"
-
-    @property
-    def state(self) -> Mapping[str, Any]:
-        return self._state
-
-    @state.setter
-    def state(self, value: Mapping[str, Any]) -> None:
-        self._state = value
-
-    @property
-    def is_resumable(self) -> bool:
-        return bool(self._cursor)
-
-    def get_cursor(self) -> Optional[Cursor]:
-        return self._cursor
-
-    def stream_slices(
-        self,
-        *,
-        sync_mode: SyncMode,
-        cursor_field: List[str] = None,
-        stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Optional[StreamSlice]]:
-        for s in self._slices:
-            if isinstance(s, StreamSlice):
-                yield s
-            else:
-                yield StreamSlice(partition=s, cursor_slice={})
-
-    def read_records(
-        self,
-        sync_mode: SyncMode,
-        cursor_field: List[str] = None,
-        stream_slice: Mapping[str, Any] = None,
-        stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Mapping[str, Any]]:
-        # The parent stream's records should always be read as full refresh
-        assert sync_mode == SyncMode.full_refresh
-
-        if not stream_slice:
-            result = self._records
-        else:
-            result = [
-                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
-                for r in self._records
-                if r["slice"] == stream_slice["slice"]
-            ]
-
-        yield from result
-
-        # Update the state only after reading the full slice
-        cursor_field = self._stream_cursor_field.eval(config=self.config)
-        if stream_slice and cursor_field and result:
-            self._state["states"].append(
-                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
-            )
-
-    def get_json_schema(self) -> Mapping[str, Any]:
-        return {}
-
-
 class MockIncrementalStream(MockStream):
     def __init__(self, slices, records, name, cursor_field="", cursor=None, date_ranges=None):
         super().__init__(slices, records, name, cursor_field, cursor)