Commit 45aa98d

Allow for streams using AsyncRetriever and DatetimeBasedCursor to perform checkpointing
1 parent b5ed82c commit 45aa98d

File tree

3 files changed: +66 -21 lines

airbyte_cdk/sources/declarative/declarative_stream.py
airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py
airbyte_cdk/sources/declarative/retrievers/async_retriever.py


airbyte_cdk/sources/declarative/declarative_stream.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,7 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+
 import logging
 from dataclasses import InitVar, dataclass, field
 from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
@@ -13,7 +14,7 @@
 )
 from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
 from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
-from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
+from airbyte_cdk.sources.declarative.retrievers import AsyncRetriever, SimpleRetriever
 from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
 from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
 from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
@@ -189,7 +190,10 @@ def state_checkpoint_interval(self) -> Optional[int]:
         return None
 
     def get_cursor(self) -> Optional[Cursor]:
-        if self.retriever and isinstance(self.retriever, SimpleRetriever):
+        if self.retriever and (
+            isinstance(self.retriever, SimpleRetriever)
+            or isinstance(self.retriever, AsyncRetriever)
+        ):
             return self.retriever.cursor
         return None
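
The get_cursor() change is the hook that lets the rest of the CDK discover a cursor for AsyncRetriever-based streams and checkpoint them. Below is a minimal, runnable sketch of the same dispatch, using stand-in classes rather than the real airbyte_cdk retrievers; the stub classes and the tuple-style isinstance check are illustrative only, not the CDK's API.

from typing import Optional


class Cursor:  # stand-in for the CDK's Cursor
    pass


class SimpleRetriever:  # stand-in
    cursor: Optional[Cursor] = Cursor()


class AsyncRetriever:  # stand-in
    cursor: Optional[Cursor] = Cursor()


class CustomRetriever:  # stand-in for a retriever with no checkpointing support
    cursor: Optional[Cursor] = None


def get_cursor(retriever: object) -> Optional[Cursor]:
    # Mirrors the diff above: AsyncRetriever now also exposes its cursor.
    if retriever and isinstance(retriever, (SimpleRetriever, AsyncRetriever)):
        return retriever.cursor
    return None


assert get_cursor(AsyncRetriever()) is not None  # new behavior in this commit
assert get_cursor(CustomRetriever()) is None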

airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,8 @@
     AsyncJobOrchestrator,
     AsyncPartition,
 )
+from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor
+from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
 from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
     SinglePartitionRouter,
 )
@@ -35,6 +37,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._job_orchestrator_factory = self.job_orchestrator_factory
         self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
         self._parameters = parameters
+        if isinstance(self.stream_slicer, DatetimeBasedCursor):
+            self._cursor: Optional[DeclarativeCursor] = self.stream_slicer
+        else:
+            self._cursor = None
+
+    @property
+    def cursor(self) -> Optional[DeclarativeCursor]:
+        return self._cursor
 
     def stream_slices(self) -> Iterable[StreamSlice]:
         slices = self.stream_slicer.stream_slices()
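
A reduced sketch of the cursor-exposure pattern added above, using stand-in classes instead of the real CDK types: the router only surfaces a cursor when its stream_slicer is a DatetimeBasedCursor, so non-incremental slicers keep returning None. The get_stream_state() return value below is a made-up example.

from typing import Optional


class StreamSlicer:  # stand-in for the CDK's StreamSlicer
    pass


class DatetimeBasedCursor(StreamSlicer):  # stand-in; the real class also slices by date windows
    def get_stream_state(self) -> dict:
        return {"updated_at": "2024-01-01T00:00:00Z"}  # illustrative value only


class AsyncJobPartitionRouter:  # reduced stand-in for the router in the diff
    def __init__(self, stream_slicer: StreamSlicer) -> None:
        # Only a DatetimeBasedCursor slicer is exposed as a cursor; anything else yields None.
        self._cursor: Optional[DatetimeBasedCursor] = (
            stream_slicer if isinstance(stream_slicer, DatetimeBasedCursor) else None
        )

    @property
    def cursor(self) -> Optional[DatetimeBasedCursor]:
        return self._cursor


assert AsyncJobPartitionRouter(DatetimeBasedCursor()).cursor is not None
assert AsyncJobPartitionRouter(StreamSlicer()).cursor is None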

airbyte_cdk/sources/declarative/retrievers/async_retriever.py

Lines changed: 50 additions & 19 deletions
@@ -9,13 +9,14 @@
 from airbyte_cdk.models import FailureType
 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
+from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
 from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
     AsyncJobPartitionRouter,
 )
 from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
 from airbyte_cdk.sources.source import ExperimentalClassWarning
 from airbyte_cdk.sources.streams.core import StreamData
-from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
+from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException
 
 
@@ -35,27 +36,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
 
     @property
     def state(self) -> StreamState:
-        """
-        As a first iteration for sendgrid, there is no state to be managed
-        """
-        return {}
-
-    @state.setter
-    def state(self, value: StreamState) -> None:
-        """
-        As a first iteration for sendgrid, there is no state to be managed
-        """
-        pass
-
-    def _get_stream_state(self) -> StreamState:
         """
         Gets the current state of the stream.
 
         Returns:
            StreamState: Mapping[str, Any]
         """
+        return self.stream_slicer.cursor.get_stream_state() if self.stream_slicer.cursor else {}
+
+    @state.setter
+    def state(self, value: StreamState) -> None:
+        """State setter, accept state serialized by state getter."""
+        if self.stream_slicer.cursor:
+            self.stream_slicer.cursor.set_initial_state(value)
 
-        return self.state
+    @property
+    def cursor(self) -> Optional[DeclarativeCursor]:
+        return self.stream_slicer.cursor
 
     def _validate_and_get_stream_slice_partition(
         self, stream_slice: Optional[StreamSlice] = None
@@ -88,13 +85,47 @@ def read_records(
         records_schema: Mapping[str, Any],
         stream_slice: Optional[StreamSlice] = None,
     ) -> Iterable[StreamData]:
-        stream_state: StreamState = self._get_stream_state()
+        _slice = stream_slice or StreamSlice(partition={}, cursor_slice={})  # None-check
+
+        stream_state: StreamState = self.state
         partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
         records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
+        most_recent_record_from_slice = None
 
-        yield from self.record_selector.filter_and_transform(
+        for stream_data in self.record_selector.filter_and_transform(
             all_data=records,
             stream_state=stream_state,
             records_schema=records_schema,
-            stream_slice=stream_slice,
-        )
+            stream_slice=_slice,
+        ):
+            if self.cursor and stream_data:
+                self.cursor.observe(_slice, stream_data)
+
+                most_recent_record_from_slice = self._get_most_recent_record(
+                    most_recent_record_from_slice, stream_data, _slice
+                )
+            yield stream_data
+
+        if self.cursor:
+            # DatetimeBasedCursor doesn't expect a partition field, but for AsyncRetriever streams this will
+            # be the slice range
+            slice_no_partition = StreamSlice(cursor_slice=_slice.cursor_slice, partition={})
+            self.cursor.close_slice(slice_no_partition, most_recent_record_from_slice)
+
+    def _get_most_recent_record(
+        self,
+        current_most_recent: Optional[Record],
+        current_record: Optional[Record],
+        stream_slice: StreamSlice,
+    ) -> Optional[Record]:
+        if self.cursor and current_record:
+            if not current_most_recent:
+                return current_record
+            else:
+                return (
+                    current_most_recent
+                    if self.cursor.is_greater_than_or_equal(current_most_recent, current_record)
+                    else current_record
+                )
+        else:
+            return None
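
A reduced, runnable sketch of the checkpointing flow that read_records() now performs: observe each record, keep track of the most recent record in the slice, then close the slice so the cursor can emit updated state. StubCursor, read_and_checkpoint, and the "updated_at" field are hypothetical stand-ins, not the CDK's DatetimeBasedCursor or its API.

from typing import Dict, Iterable, Iterator, Mapping, Optional

Record = Dict[str, str]  # simplified stand-in for airbyte_cdk's Record


class StubCursor:  # stand-in cursor keyed on a hypothetical "updated_at" field
    def __init__(self) -> None:
        self.state: Dict[str, str] = {}

    def observe(self, stream_slice: Mapping, record: Record) -> None:
        pass  # the real cursor may update internal bounds here; no-op in the stub

    def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
        return first["updated_at"] >= second["updated_at"]

    def close_slice(self, stream_slice: Mapping, most_recent_record: Optional[Record]) -> None:
        if most_recent_record:
            self.state = {"updated_at": most_recent_record["updated_at"]}


def read_and_checkpoint(cursor: StubCursor, records: Iterable[Record]) -> Iterator[Record]:
    stream_slice: Mapping = {}
    most_recent: Optional[Record] = None
    for record in records:
        cursor.observe(stream_slice, record)
        # Keep the latest record seen in this slice, mirroring _get_most_recent_record().
        if most_recent is None or not cursor.is_greater_than_or_equal(most_recent, record):
            most_recent = record
        yield record
    # Closing the slice is what lets the cursor surface updated (checkpointed) state.
    cursor.close_slice(stream_slice, most_recent)


cursor = StubCursor()
list(read_and_checkpoint(cursor, [{"updated_at": "2024-01-01"}, {"updated_at": "2024-02-01"}]))
assert cursor.state == {"updated_at": "2024-02-01"}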
