Commit 1c9049a

fix: revert remerge concurrent cdk builder change because of flaky test (#705)
1 parent addd443 commit 1c9049a

23 files changed: +358 additions, −635 deletions

airbyte_cdk/connector_builder/connector_builder_handler.py

Lines changed: 36 additions & 32 deletions
@@ -3,8 +3,8 @@
 #


-from dataclasses import asdict
-from typing import Any, Dict, List, Mapping, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, ClassVar, Dict, List, Mapping

 from airbyte_cdk.connector_builder.test_reader import TestReader
 from airbyte_cdk.models import (
@@ -15,32 +15,45 @@
     Type,
 )
 from airbyte_cdk.models import Type as MessageType
-from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
-    ConcurrentDeclarativeSource,
-    TestLimits,
-)
 from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
 from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
+from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
+    ModelToComponentFactory,
+)
 from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
 from airbyte_cdk.utils.datetime_helpers import ab_datetime_now
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException

+DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
+DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
+DEFAULT_MAXIMUM_RECORDS = 100
+DEFAULT_MAXIMUM_STREAMS = 100
+
 MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice"
 MAX_SLICES_KEY = "max_slices"
 MAX_RECORDS_KEY = "max_records"
 MAX_STREAMS_KEY = "max_streams"


+@dataclass
+class TestLimits:
+    __test__: ClassVar[bool] = False  # Tell Pytest this is not a Pytest class, despite its name
+
+    max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
+    max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
+    max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
+    max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)
+
+
 def get_limits(config: Mapping[str, Any]) -> TestLimits:
     command_config = config.get("__test_read_config", {})
-    return TestLimits(
-        max_records=command_config.get(MAX_RECORDS_KEY, TestLimits.DEFAULT_MAX_RECORDS),
-        max_pages_per_slice=command_config.get(
-            MAX_PAGES_PER_SLICE_KEY, TestLimits.DEFAULT_MAX_PAGES_PER_SLICE
-        ),
-        max_slices=command_config.get(MAX_SLICES_KEY, TestLimits.DEFAULT_MAX_SLICES),
-        max_streams=command_config.get(MAX_STREAMS_KEY, TestLimits.DEFAULT_MAX_STREAMS),
+    max_pages_per_slice = (
+        command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE
     )
+    max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES
+    max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS
+    max_streams = command_config.get(MAX_STREAMS_KEY) or DEFAULT_MAXIMUM_STREAMS
+    return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams)


 def should_migrate_manifest(config: Mapping[str, Any]) -> bool:
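
Reviewer note: the restored get_limits falls back with "value or DEFAULT" rather than dict.get(key, DEFAULT), so an explicit falsy override (for example 0) is silently replaced by the default. A minimal sketch of the difference, using hypothetical config values:

    # Hypothetical test-read config, not taken from this commit.
    command_config = {"max_records": 0}

    # dict.get with a default honors an explicit 0.
    command_config.get("max_records", 100)    # -> 0

    # The restored `or` fallback treats 0 as missing.
    command_config.get("max_records") or 100  # -> 100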
@@ -62,30 +75,21 @@ def should_normalize_manifest(config: Mapping[str, Any]) -> bool:
     return config.get("__should_normalize", False)


-def create_source(
-    config: Mapping[str, Any],
-    limits: TestLimits,
-    catalog: Optional[ConfiguredAirbyteCatalog],
-    state: Optional[List[AirbyteStateMessage]],
-) -> ConcurrentDeclarativeSource[Optional[List[AirbyteStateMessage]]]:
+def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource:
     manifest = config["__injected_declarative_manifest"]
-
-    # We enforce a concurrency level of 1 so that the stream is processed on a single thread
-    # to retain ordering for the grouping of the builder message responses.
-    if "concurrency_level" in manifest:
-        manifest["concurrency_level"]["default_concurrency"] = 1
-    else:
-        manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1}
-
-    return ConcurrentDeclarativeSource(
-        catalog=catalog,
+    return ManifestDeclarativeSource(
         config=config,
-        state=state,
-        source_config=manifest,
         emit_connector_builder_messages=True,
+        source_config=manifest,
         migrate_manifest=should_migrate_manifest(config),
         normalize_manifest=should_normalize_manifest(config),
-        limits=limits,
+        component_factory=ModelToComponentFactory(
+            emit_connector_builder_messages=True,
+            limit_pages_fetched_per_slice=limits.max_pages_per_slice,
+            limit_slices_fetched=limits.max_slices,
+            disable_retries=True,
+            disable_cache=True,
+        ),
     )


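For orientation, a sketch of how the reverted entry points fit together, with hypothetical config values (the double-underscore keys appear in this file; everything else is illustrative):

    config = {
        "__injected_declarative_manifest": {},  # the manifest dict goes here
        "__test_read_config": {"max_records": 50},
    }
    limits = get_limits(config)             # TestLimits(max_records=50, ...)
    source = create_source(config, limits)  # catalog and state are no longer passed here
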
airbyte_cdk/connector_builder/main.py

Lines changed: 3 additions & 3 deletions
@@ -91,12 +91,12 @@ def handle_connector_builder_request(
 def handle_request(args: List[str]) -> str:
     command, config, catalog, state = get_config_and_catalog_from_args(args)
     limits = get_limits(config)
-    source = create_source(config=config, limits=limits, catalog=catalog, state=state)
-    return orjson.dumps(  # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
+    source = create_source(config, limits)
+    return orjson.dumps(
         AirbyteMessageSerializer.dump(
             handle_connector_builder_request(source, command, config, catalog, state, limits)
         )
-    ).decode()
+    ).decode()  # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage


 if __name__ == "__main__":
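
One detail worth remembering when reading this hunk: orjson.dumps() returns bytes rather than str (unlike the stdlib json.dumps), which is why the call chain ends in .decode(); the relocated type-ignore comment only moves where mypy's complaint is silenced. A one-line illustration:

    import orjson

    payload = orjson.dumps({"type": "RECORD"})  # bytes
    text = payload.decode()                     # '{"type":"RECORD"}'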

airbyte_cdk/connector_builder/test_reader/helpers.py

Lines changed: 2 additions & 24 deletions
@@ -5,7 +5,7 @@
 import json
 from copy import deepcopy
 from json import JSONDecodeError
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional

 from airbyte_cdk.connector_builder.models import (
     AuxiliaryRequest,
@@ -17,8 +17,6 @@
 from airbyte_cdk.models import (
     AirbyteLogMessage,
     AirbyteMessage,
-    AirbyteStateBlob,
-    AirbyteStateMessage,
     OrchestratorType,
     TraceType,
 )
@@ -468,7 +466,7 @@ def handle_current_slice(
     return StreamReadSlices(
         pages=current_slice_pages,
         slice_descriptor=current_slice_descriptor,
-        state=[convert_state_blob_to_mapping(latest_state_message)] if latest_state_message else [],
+        state=[latest_state_message] if latest_state_message else [],
         auxiliary_requests=auxiliary_requests if auxiliary_requests else [],
     )

@@ -720,23 +718,3 @@ def get_auxiliary_request_type(stream: dict, http: dict) -> str:  # type: ignore
     Determines the type of the auxiliary request based on the stream and HTTP properties.
     """
     return "PARENT_STREAM" if stream.get("is_substream", False) else str(http.get("type", None))
-
-
-def convert_state_blob_to_mapping(
-    state_message: Union[AirbyteStateMessage, Dict[str, Any]],
-) -> Dict[str, Any]:
-    """
-    The AirbyteStreamState stores state as an AirbyteStateBlob which deceivingly is not
-    a dictionary, but rather a list of kwargs fields. This in turn causes it to not be
-    properly turned into a dictionary when translating this back into response output
-    by the connector_builder_handler using asdict()
-    """
-
-    if isinstance(state_message, AirbyteStateMessage) and state_message.stream:
-        state_value = state_message.stream.stream_state
-        if isinstance(state_value, AirbyteStateBlob):
-            state_value_mapping = {k: v for k, v in state_value.__dict__.items()}
-            state_message.stream.stream_state = state_value_mapping  # type: ignore # we intentionally set this as a Dict so that StreamReadSlices is translated properly in the resulting HTTP response
-        return state_message  # type: ignore # See above, but when this is an AirbyteStateMessage we must convert AirbyteStateBlob to a Dict
-    else:
-        return state_message  # type: ignore # This is guaranteed to be a Dict since we check isinstance AirbyteStateMessage above
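
Context for the removed helper: dataclasses.asdict() recurses only into dataclass instances, dicts, lists, and tuples, so a non-dataclass attribute such as AirbyteStateBlob passes through unconverted; that is the asymmetry the helper papered over. A self-contained stdlib illustration with generic stand-in classes (not the CDK types):

    from dataclasses import asdict, dataclass

    class Blob:  # stands in for a non-dataclass state object
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    @dataclass
    class Slice:
        state: object

    print(asdict(Slice(state=Blob(cursor="2024-01-01"))))
    # {'state': <__main__.Blob object at 0x...>} -- the blob is not turned into a dict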

airbyte_cdk/connector_builder/test_reader/message_grouper.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def get_message_groups(
     latest_state_message: Optional[Dict[str, Any]] = None
     slice_auxiliary_requests: List[AuxiliaryRequest] = []

-    while message := next(messages, None):
+    while records_count < limit and (message := next(messages, None)):
         json_message = airbyte_message_to_json(message)

         if is_page_http_request_for_different_stream(json_message, stream_name):
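
The restored loop checks records_count against the limit before pulling the next message, so reading stops as soon as the record budget is exhausted instead of draining the iterator. A standalone sketch of the pattern with hypothetical names:

    from typing import Iterator

    def take_until_limit(messages: Iterator[str], limit: int) -> list[str]:
        records: list[str] = []
        # next(..., None) ends the loop on exhaustion; the count check ends it early
        while len(records) < limit and (message := next(messages, None)):
            records.append(message)
        return records

    print(take_until_limit(iter(["a", "b", "c"]), limit=2))  # ['a', 'b']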

airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py

Lines changed: 21 additions & 14 deletions
@@ -95,14 +95,11 @@ def on_partition(self, partition: Partition) -> None:
         """
         stream_name = partition.stream_name()
         self._streams_to_running_partitions[stream_name].add(partition)
-        cursor = self._stream_name_to_instance[stream_name].cursor
         if self._slice_logger.should_log_slice_message(self._logger):
             self._message_repository.emit_message(
                 self._slice_logger.create_slice_log_message(partition.to_slice())
             )
-        self._thread_pool_manager.submit(
-            self._partition_reader.process_partition, partition, cursor
-        )
+        self._thread_pool_manager.submit(self._partition_reader.process_partition, partition)

     def on_partition_complete_sentinel(
         self, sentinel: PartitionCompleteSentinel
@@ -115,16 +112,26 @@ def on_partition_complete_sentinel(
         """
         partition = sentinel.partition

-        partitions_running = self._streams_to_running_partitions[partition.stream_name()]
-        if partition in partitions_running:
-            partitions_running.remove(partition)
-            # If all partitions were generated and this was the last one, the stream is done
-            if (
-                partition.stream_name() not in self._streams_currently_generating_partitions
-                and len(partitions_running) == 0
-            ):
-                yield from self._on_stream_is_done(partition.stream_name())
-        yield from self._message_repository.consume_queue()
+        try:
+            if sentinel.is_successful:
+                stream = self._stream_name_to_instance[partition.stream_name()]
+                stream.cursor.close_partition(partition)
+        except Exception as exception:
+            self._flag_exception(partition.stream_name(), exception)
+            yield AirbyteTracedException.from_exception(
+                exception, stream_descriptor=StreamDescriptor(name=partition.stream_name())
+            ).as_sanitized_airbyte_message()
+        finally:
+            partitions_running = self._streams_to_running_partitions[partition.stream_name()]
+            if partition in partitions_running:
+                partitions_running.remove(partition)
+                # If all partitions were generated and this was the last one, the stream is done
+                if (
+                    partition.stream_name() not in self._streams_currently_generating_partitions
+                    and len(partitions_running) == 0
+                ):
+                    yield from self._on_stream_is_done(partition.stream_name())
+            yield from self._message_repository.consume_queue()

     def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
         """

airbyte_cdk/sources/concurrent_source/concurrent_source.py

Lines changed: 18 additions & 29 deletions
@@ -4,7 +4,7 @@
 import concurrent
 import logging
 from queue import Queue
-from typing import Iterable, Iterator, List, Optional
+from typing import Iterable, Iterator, List

 from airbyte_cdk.models import AirbyteMessage
 from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
@@ -16,7 +16,7 @@
 from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
 from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
-from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader
+from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.types import (
     PartitionCompleteSentinel,
@@ -43,7 +43,6 @@ def create(
         logger: logging.Logger,
         slice_logger: SliceLogger,
         message_repository: MessageRepository,
-        queue: Optional[Queue[QueueItem]] = None,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
     ) -> "ConcurrentSource":
         is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
@@ -60,21 +59,19 @@
             logger,
         )
         return ConcurrentSource(
-            threadpool=threadpool,
-            logger=logger,
-            slice_logger=slice_logger,
-            queue=queue,
-            message_repository=message_repository,
-            initial_number_partitions_to_generate=initial_number_of_partitions_to_generate,
-            timeout_seconds=timeout_seconds,
+            threadpool,
+            logger,
+            slice_logger,
+            message_repository,
+            initial_number_of_partitions_to_generate,
+            timeout_seconds,
         )

     def __init__(
         self,
         threadpool: ThreadPoolManager,
         logger: logging.Logger,
         slice_logger: SliceLogger = DebugSliceLogger(),
-        queue: Optional[Queue[QueueItem]] = None,
         message_repository: MessageRepository = InMemoryMessageRepository(),
         initial_number_partitions_to_generate: int = 1,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
@@ -94,36 +91,33 @@ def __init__(
         self._initial_number_partitions_to_generate = initial_number_partitions_to_generate
         self._timeout_seconds = timeout_seconds

-        # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less
-        # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
-        # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
-        # information and might even need to be configurable depending on the source
-        self._queue = queue or Queue(maxsize=10_000)
-
     def read(
         self,
         streams: List[AbstractStream],
     ) -> Iterator[AirbyteMessage]:
         self._logger.info("Starting syncing")
+
+        # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less
+        # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
+        # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
+        # information and might even need to be configurable depending on the source
+        queue: Queue[QueueItem] = Queue(maxsize=10_000)
         concurrent_stream_processor = ConcurrentReadProcessor(
             streams,
-            PartitionEnqueuer(self._queue, self._threadpool),
+            PartitionEnqueuer(queue, self._threadpool),
             self._threadpool,
             self._logger,
             self._slice_logger,
             self._message_repository,
-            PartitionReader(
-                self._queue,
-                PartitionLogger(self._slice_logger, self._logger, self._message_repository),
-            ),
+            PartitionReader(queue),
         )

         # Enqueue initial partition generation tasks
         yield from self._submit_initial_partition_generators(concurrent_stream_processor)

         # Read from the queue until all partitions were generated and read
         yield from self._consume_from_queue(
-            self._queue,
+            queue,
             concurrent_stream_processor,
         )
         self._threadpool.check_for_errors_and_shutdown()
@@ -147,10 +141,7 @@ def _consume_from_queue(
                 airbyte_message_or_record_or_exception,
                 concurrent_stream_processor,
             )
-            # In the event that a partition raises an exception, anything remaining in
-            # the queue will be missed because is_done() can raise an exception and exit
-            # out of this loop before remaining items are consumed
-            if queue.empty() and concurrent_stream_processor.is_done():
+            if concurrent_stream_processor.is_done() and queue.empty():
                 # all partitions were generated and processed. we're done here
                 break

@@ -170,7 +161,5 @@ def _handle_item(
             yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
         elif isinstance(queue_item, Record):
             yield from concurrent_stream_processor.on_record(queue_item)
-        elif isinstance(queue_item, AirbyteMessage):
-            yield queue_item
         else:
             raise ValueError(f"Unknown queue item type: {type(queue_item)}")
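
The reverted read() builds a fresh bounded queue per sync instead of sharing an injected self._queue, and the maxsize is what throttles partition producers when the main thread falls behind. A minimal standalone illustration of that backpressure (not CDK code):

    from queue import Full, Queue

    queue: Queue[int] = Queue(maxsize=2)
    queue.put(1)
    queue.put(2)
    try:
        queue.put(3, timeout=0.1)  # would block until a consumer drains a slot
    except Full:
        print("producer throttled until the main thread consumes items")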
