Skip to content

poc: connector builder using concurrent cdk #460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions airbyte_cdk/connector_builder/connector_builder_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Mapping
from typing import Any, Dict, List, Mapping, Optional

from airbyte_cdk.connector_builder.test_reader import TestReader
from airbyte_cdk.models import (
Expand All @@ -15,6 +15,14 @@
Type,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE,
DEFAULT_MAXIMUM_NUMBER_OF_SLICES,
DEFAULT_MAXIMUM_RECORDS,
DEFAULT_MAXIMUM_STREAMS,
ConcurrentDeclarativeSource,
TestLimits,
)
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
Expand All @@ -24,25 +32,12 @@
from airbyte_cdk.utils.datetime_helpers import ab_datetime_now
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
DEFAULT_MAXIMUM_RECORDS = 100
DEFAULT_MAXIMUM_STREAMS = 100

MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice"
MAX_SLICES_KEY = "max_slices"
MAX_RECORDS_KEY = "max_records"
MAX_STREAMS_KEY = "max_streams"


@dataclass
class TestLimits:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to move this to avoid a circular dependency: concurrent_declarative_source.py needs to know about TestLimits, while connector_builder_handler.py needs to know about concurrent_declarative_source.

max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)


def get_limits(config: Mapping[str, Any]) -> TestLimits:
command_config = config.get("__test_read_config", {})
max_pages_per_slice = (
Expand All @@ -54,19 +49,24 @@ def get_limits(config: Mapping[str, Any]) -> TestLimits:
return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams)


def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource:
def _ensure_concurrency_level(manifest: Dict[str, Any]) -> None:
# We need to do that to ensure that the state in the StreamReadSlices only contains the changes for one slice
# Note that this is below the _LOWEST_SAFE_CONCURRENCY_LEVEL but it is fine in this case because we are limiting the number of slices
# being generated which means that the memory usage is limited anyway
if "concurrency_level" not in manifest:
manifest["concurrency_level"] = {"type": "ConcurrencyLevel"}
manifest["concurrency_level"]["default_concurrency"] = 1

def create_source(
    config: Mapping[str, Any],
    limits: TestLimits,
    catalog: Optional[ConfiguredAirbyteCatalog] = None,
    state: Any = None,
) -> ManifestDeclarativeSource:
    """Build the source used by the connector builder to run test reads.

    :param config: connector config; must contain the "__injected_declarative_manifest" key
    :param limits: test-read limits (max records, pages per slice, slices, streams)
        forwarded to the source
    :param catalog: optional configured catalog driving which streams are read
    :param state: optional incoming state passed through to the source
    :return: a ConcurrentDeclarativeSource configured for builder test reads
        (assumed compatible with the declared ManifestDeclarativeSource return
        type — TODO confirm the inheritance relationship)
    :raises KeyError: if the manifest was not injected into the config
    """
    manifest = config["__injected_declarative_manifest"]
    # Pin concurrency to 1 so each StreamReadSlices state reflects a single slice.
    _ensure_concurrency_level(manifest)
    return ConcurrentDeclarativeSource(
        config=config,
        catalog=catalog,
        state=state,
        source_config=manifest,
        component_factory=ModelToComponentFactory(
            emit_connector_builder_messages=True,
            disable_retries=True,
            disable_cache=True,
        ),
        emit_connector_builder_messages=True,
        limits=limits,
    )


Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/connector_builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def handle_connector_builder_request(
def handle_request(args: List[str]) -> str:
command, config, catalog, state = get_config_and_catalog_from_args(args)
limits = get_limits(config)
source = create_source(config, limits)
source = create_source(config, limits, catalog, state)
return orjson.dumps(
AirbyteMessageSerializer.dump(
handle_connector_builder_request(source, command, config, catalog, state, limits)
Expand Down
43 changes: 14 additions & 29 deletions airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(
partition_enqueuer: PartitionEnqueuer,
thread_pool_manager: ThreadPoolManager,
logger: logging.Logger,
slice_logger: SliceLogger,
message_repository: MessageRepository,
partition_reader: PartitionReader,
):
Expand All @@ -44,7 +43,6 @@ def __init__(
:param partition_enqueuer: PartitionEnqueuer instance
:param thread_pool_manager: ThreadPoolManager instance
:param logger: Logger instance
:param slice_logger: SliceLogger instance
:param message_repository: MessageRepository instance
:param partition_reader: PartitionReader instance
"""
Expand All @@ -59,7 +57,6 @@ def __init__(
self._stream_instances_to_start_partition_generation = stream_instances_to_read_from
self._streams_currently_generating_partitions: List[str] = []
self._logger = logger
self._slice_logger = slice_logger
self._message_repository = message_repository
self._partition_reader = partition_reader
self._streams_done: Set[str] = set()
Expand Down Expand Up @@ -95,11 +92,7 @@ def on_partition(self, partition: Partition) -> None:
"""
stream_name = partition.stream_name()
self._streams_to_running_partitions[stream_name].add(partition)
if self._slice_logger.should_log_slice_message(self._logger):
self._message_repository.emit_message(
self._slice_logger.create_slice_log_message(partition.to_slice())
)
self._thread_pool_manager.submit(self._partition_reader.process_partition, partition)
self._thread_pool_manager.submit(self._partition_reader.process_partition, partition, self._stream_name_to_instance[partition.stream_name()].cursor)

def on_partition_complete_sentinel(
self, sentinel: PartitionCompleteSentinel
Expand All @@ -112,26 +105,19 @@ def on_partition_complete_sentinel(
"""
partition = sentinel.partition

try:
if sentinel.is_successful:
stream = self._stream_name_to_instance[partition.stream_name()]
stream.cursor.close_partition(partition)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved stream.cursor.close_partition(partition) to PartitionReader, which meant that there was no need for catching exceptions here

except Exception as exception:
self._flag_exception(partition.stream_name(), exception)
yield AirbyteTracedException.from_exception(
exception, stream_descriptor=StreamDescriptor(name=partition.stream_name())
).as_sanitized_airbyte_message()
finally:
partitions_running = self._streams_to_running_partitions[partition.stream_name()]
if partition in partitions_running:
partitions_running.remove(partition)
# If all partitions were generated and this was the last one, the stream is done
if (
partition.stream_name() not in self._streams_currently_generating_partitions
and len(partitions_running) == 0
):
yield from self._on_stream_is_done(partition.stream_name())
yield from self._message_repository.consume_queue()
if sentinel.is_successful:
stream = self._stream_name_to_instance[partition.stream_name()]

partitions_running = self._streams_to_running_partitions[partition.stream_name()]
if partition in partitions_running:
partitions_running.remove(partition)
# If all partitions were generated and this was the last one, the stream is done
if (
partition.stream_name() not in self._streams_currently_generating_partitions
and len(partitions_running) == 0
):
yield from self._on_stream_is_done(partition.stream_name())
yield from self._message_repository.consume_queue()

def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
"""
Expand Down Expand Up @@ -160,7 +146,6 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
stream.as_airbyte_stream(), AirbyteStreamStatus.RUNNING
)
self._record_counter[stream.name] += 1
stream.cursor.observe(record)
yield message
yield from self._message_repository.consume_queue()

Expand Down
18 changes: 11 additions & 7 deletions airbyte_cdk/sources/concurrent_source/concurrent_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import concurrent
import logging
from queue import Queue
from typing import Iterable, Iterator, List
from typing import Iterable, Iterator, List, Optional

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
Expand All @@ -16,7 +16,7 @@
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.types import (
PartitionCompleteSentinel,
Expand Down Expand Up @@ -44,6 +44,7 @@ def create(
slice_logger: SliceLogger,
message_repository: MessageRepository,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
queue: Optional[Queue[QueueItem]] = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the MessageRepository also needs access to the queue, the queue has to be passed in here instead of being created internally.

) -> "ConcurrentSource":
is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
too_many_generator = (
Expand All @@ -65,6 +66,7 @@ def create(
message_repository,
initial_number_of_partitions_to_generate,
timeout_seconds,
queue,
)

def __init__(
Expand All @@ -75,6 +77,7 @@ def __init__(
message_repository: MessageRepository = InMemoryMessageRepository(),
initial_number_partitions_to_generate: int = 1,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
queue: Optional[Queue[QueueItem]] = None,
) -> None:
"""
:param threadpool: The threadpool to submit tasks to
Expand All @@ -90,6 +93,7 @@ def __init__(
self._message_repository = message_repository
self._initial_number_partitions_to_generate = initial_number_partitions_to_generate
self._timeout_seconds = timeout_seconds
self._queue = queue if queue else Queue(maxsize=10_000)

def read(
self,
Expand All @@ -101,23 +105,21 @@ def read(
# threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
# partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
# information and might even need to be configurable depending on the source
queue: Queue[QueueItem] = Queue(maxsize=10_000)
concurrent_stream_processor = ConcurrentReadProcessor(
streams,
PartitionEnqueuer(queue, self._threadpool),
PartitionEnqueuer(self._queue, self._threadpool),
self._threadpool,
self._logger,
self._slice_logger,
self._message_repository,
PartitionReader(queue),
PartitionReader(self._queue, PartitionLogger(self._slice_logger, self._logger, self._message_repository)),
)

# Enqueue initial partition generation tasks
yield from self._submit_initial_partition_generators(concurrent_stream_processor)

# Read from the queue until all partitions were generated and read
yield from self._consume_from_queue(
queue,
self._queue,
concurrent_stream_processor,
)
self._threadpool.check_for_errors_and_shutdown()
Expand Down Expand Up @@ -161,5 +163,7 @@ def _handle_item(
yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
elif isinstance(queue_item, Record):
yield from concurrent_stream_processor.on_record(queue_item)
elif isinstance(queue_item, AirbyteMessage):
yield queue_item
else:
raise ValueError(f"Unknown queue item type: {type(queue_item)}")
Loading
Loading