Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
for stream_slice_tuple in product:
partition = dict(ChainMap(*[s.partition for s in stream_slice_tuple])) # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
cursor_slices = [s.cursor_slice for s in stream_slice_tuple if s.cursor_slice]
extra_fields = dict(ChainMap(*[s.extra_fields for s in stream_slice_tuple])) # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
if len(cursor_slices) > 1:
raise ValueError(
f"There should only be a single cursor slice. Found {cursor_slices}"
Expand All @@ -157,7 +158,9 @@ def stream_slices(self) -> Iterable[StreamSlice]:
cursor_slice = cursor_slices[0]
else:
cursor_slice = {}
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=extra_fields
)

def set_initial_state(self, stream_state: StreamState) -> None:
"""
Expand Down
93 changes: 93 additions & 0 deletions unit_tests/sources/declarative/partition_routers/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#

from typing import Any, Iterable, List, Mapping, Optional, Union

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.streams.checkpoint import Cursor
from airbyte_cdk.sources.types import Record, StreamSlice


class MockStream(DeclarativeStream):
    """In-memory stand-in for a parent ``DeclarativeStream`` used by partition-router tests.

    Emits pre-canned slices and records instead of doing I/O, and records
    per-partition cursor state into ``self._state["states"]`` as slices are read.
    """

    def __init__(self, slices, records, name, cursor_field="", cursor=None):
        self.config = {}
        # Pre-canned data this fake stream will serve.
        self._slices = slices
        self._records = records
        # Accept either a raw string (wrapped here) or an already-built InterpolatedString.
        self._stream_cursor_field = (
            InterpolatedString.create(cursor_field, parameters={})
            if isinstance(cursor_field, str)
            else cursor_field
        )
        self._name = name
        # Mimics the per-partition state shape: {"states": [{<cursor_field>: ..., "partition": ...}]}
        self._state = {"states": []}
        self._cursor = cursor

    @property
    def name(self) -> str:
        """Stream name supplied at construction time."""
        return self._name

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        """Fixed primary key; the tests do not vary it."""
        return "id"

    @property
    def state(self) -> Mapping[str, Any]:
        """Current (mutable) stream state."""
        return self._state

    @state.setter
    def state(self, value: Mapping[str, Any]) -> None:
        self._state = value

    @property
    def is_resumable(self) -> bool:
        # Resumable only when a cursor was injected.
        return bool(self._cursor)

    def get_cursor(self) -> Optional[Cursor]:
        """Return the injected checkpoint cursor, if any."""
        return self._cursor

    def stream_slices(
        self,
        *,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Optional[StreamSlice]]:
        """Yield the pre-canned slices, wrapping plain mappings into ``StreamSlice``."""
        for s in self._slices:
            if isinstance(s, StreamSlice):
                yield s
            else:
                yield StreamSlice(partition=s, cursor_slice={})

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """Yield records for ``stream_slice`` (all records when no slice is given).

        With a slice, records are filtered on matching ``r["slice"]`` and wrapped in
        ``Record``; without one, the raw dicts are yielded as-is. State is appended
        only after the caller has consumed every record of the slice (generator
        semantics: the code after ``yield from`` runs on exhaustion).
        """
        # The parent stream's records should always be read as full refresh
        assert sync_mode == SyncMode.full_refresh

        if not stream_slice:
            result = self._records
        else:
            result = [
                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
                for r in self._records
                if r["slice"] == stream_slice["slice"]
            ]

        yield from result

        # Update the state only after reading the full slice
        cursor_field = self._stream_cursor_field.eval(config=self.config)
        if stream_slice and cursor_field and result:
            self._state["states"].append(
                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
            )

    def get_json_schema(self) -> Mapping[str, Any]:
        """Schema is irrelevant to these tests; return an empty one."""
        return {}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
CartesianProductStreamSlicer,
ListPartitionRouter,
)
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
ParentStreamConfig,
SubstreamPartitionRouter,
)
from airbyte_cdk.sources.declarative.requesters.request_option import (
RequestOption,
RequestOptionType,
)
from airbyte_cdk.sources.types import StreamSlice
from unit_tests.sources.declarative.partition_routers.helpers import MockStream


@pytest.mark.parametrize(
Expand Down Expand Up @@ -171,6 +176,68 @@ def test_substream_slicer(test_name, stream_slicers, expected_slices):
assert slices == expected_slices


@pytest.mark.parametrize(
    "test_name, stream_slicers, expected_slices",
    [
        (
            "test_single_stream_slicer",
            [
                SubstreamPartitionRouter(
                    parent_stream_configs=[
                        ParentStreamConfig(
                            stream=MockStream(
                                [{}],
                                [
                                    {"a": {"b": 0}, "extra_field_key": "extra_field_value_0"},
                                    {"a": {"b": 1}, "extra_field_key": "extra_field_value_1"},
                                    {"a": {"c": 2}, "extra_field_key": "extra_field_value_2"},
                                    {"a": {"b": 3}, "extra_field_key": "extra_field_value_3"},
                                ],
                                "first_stream",
                            ),
                            parent_key="a/b",
                            partition_field="first_stream_id",
                            parameters={},
                            config={},
                            extra_fields=[["extra_field_key"]],
                        )
                    ],
                    parameters={},
                    config={},
                ),
            ],
            [
                StreamSlice(
                    partition={"first_stream_id": 0, "parent_slice": {}},
                    cursor_slice={},
                    extra_fields={"extra_field_key": "extra_field_value_0"},
                ),
                StreamSlice(
                    partition={"first_stream_id": 1, "parent_slice": {}},
                    cursor_slice={},
                    extra_fields={"extra_field_key": "extra_field_value_1"},
                ),
                StreamSlice(
                    partition={"first_stream_id": 3, "parent_slice": {}},
                    cursor_slice={},
                    extra_fields={"extra_field_key": "extra_field_value_3"},
                ),
            ],
        )
    ],
)
def test_substream_slicer_with_extra_fields(test_name, stream_slicers, expected_slices):
    """Check that extra_fields read from parent records survive Cartesian-product slicing."""
    slicer = CartesianProductStreamSlicer(stream_slicers=stream_slicers, parameters={})
    actual = list(slicer.stream_slices())

    # Compare partitions and extra_fields as separate assertions so a failure
    # pinpoints which attribute diverged.
    assert [s.partition for s in actual] == [s.partition for s in expected_slices]
    assert [s.extra_fields for s in actual] == [s.extra_fields for s in expected_slices]


def test_stream_slices_raises_exception_if_multiple_cursor_slice_components():
stream_slicers = [
DatetimeBasedCursor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airbyte_cdk.sources.streams.checkpoint import Cursor
from airbyte_cdk.sources.types import Record, StreamSlice
from airbyte_cdk.utils import AirbyteTracedException
from unit_tests.sources.declarative.partition_routers.helpers import MockStream

parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}]
more_records = [
Expand Down Expand Up @@ -63,88 +64,6 @@
)


class MockStream(DeclarativeStream):
    """Fake parent ``DeclarativeStream`` serving canned slices/records for these tests.

    Tracks per-partition cursor state in ``self._state["states"]``, appending an
    entry only once a slice's records have been fully consumed.
    """

    def __init__(self, slices, records, name, cursor_field="", cursor=None):
        self.config = {}
        # Canned slices and records to emit; no real I/O happens.
        self._slices = slices
        self._records = records
        # A plain string is interpolated here; an InterpolatedString is used as-is.
        self._stream_cursor_field = (
            InterpolatedString.create(cursor_field, parameters={})
            if isinstance(cursor_field, str)
            else cursor_field
        )
        self._name = name
        # State entries look like {<cursor_field>: <last value>, "partition": <slice key>}.
        self._state = {"states": []}
        self._cursor = cursor

    @property
    def name(self) -> str:
        """Name given at construction."""
        return self._name

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        """Hard-coded primary key; not exercised by these tests."""
        return "id"

    @property
    def state(self) -> Mapping[str, Any]:
        """Mutable stream state accumulated while reading."""
        return self._state

    @state.setter
    def state(self, value: Mapping[str, Any]) -> None:
        self._state = value

    @property
    def is_resumable(self) -> bool:
        # Only resumable when a cursor was provided.
        return bool(self._cursor)

    def get_cursor(self) -> Optional[Cursor]:
        """Return the injected cursor, or None."""
        return self._cursor

    def stream_slices(
        self,
        *,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Optional[StreamSlice]]:
        """Yield the canned slices, promoting bare mappings to ``StreamSlice``."""
        for s in self._slices:
            if isinstance(s, StreamSlice):
                yield s
            else:
                yield StreamSlice(partition=s, cursor_slice={})

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """Yield the records matching ``stream_slice`` (all raw dicts when absent).

        State is appended after ``yield from`` completes, i.e. only once the
        consumer has exhausted the slice — incremental tests rely on this timing.
        """
        # The parent stream's records should always be read as full refresh
        assert sync_mode == SyncMode.full_refresh

        if not stream_slice:
            result = self._records
        else:
            result = [
                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
                for r in self._records
                if r["slice"] == stream_slice["slice"]
            ]

        yield from result

        # Update the state only after reading the full slice
        cursor_field = self._stream_cursor_field.eval(config=self.config)
        if stream_slice and cursor_field and result:
            self._state["states"].append(
                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
            )

    def get_json_schema(self) -> Mapping[str, Any]:
        """Schema is not used by these tests; return empty."""
        return {}


class MockIncrementalStream(MockStream):
def __init__(self, slices, records, name, cursor_field="", cursor=None, date_ranges=None):
super().__init__(slices, records, name, cursor_field, cursor)
Expand Down
Loading