Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2827,13 +2827,9 @@ def create_record_selector(
else None
)

if model.transform_before_filtering is None:
# default to False if not set
model.transform_before_filtering = False

assert model.transform_before_filtering is not None # for mypy

transform_before_filtering = model.transform_before_filtering
transform_before_filtering = (
False if model.transform_before_filtering is None else model.transform_before_filtering
)
if client_side_incremental_sync:
record_filter = ClientSideIncrementalRecordFilterDecorator(
config=config,
Expand All @@ -2843,7 +2839,11 @@ def create_record_selector(
else None,
**client_side_incremental_sync,
)
transform_before_filtering = True
transform_before_filtering = (
True
if model.transform_before_filtering is None
else model.transform_before_filtering
)

if model.schema_normalization is None:
# default to no schema normalization if not set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import freezegun
import isodate
import pytest
from typing_extensions import deprecated

from airbyte_cdk.models import (
Expand Down Expand Up @@ -1876,6 +1877,69 @@ def test_stream_using_is_client_side_incremental_has_cursor_state():
assert client_side_incremental_cursor_state == expected_cursor_value


@pytest.mark.parametrize(
    "expected_transform_before_filtering",
    [
        pytest.param(True, id="transform before filtering"),
        pytest.param(False, id="transform after filtering"),
        pytest.param(None, id="default transform before filtering"),
    ],
)
def test_stream_using_is_client_side_incremental_has_transform_before_filtering_according_to_manifest(
    expected_transform_before_filtering,
):
    """Check ``transform_before_filtering`` on a client-side incremental stream's record selector.

    The parameter doubles as the manifest input and the expected result:

    * ``True`` / ``False`` are written into the ``locations_stream`` record
      selector definition and must be reflected unchanged on the built selector.
    * ``None`` leaves the manifest untouched; the built selector is then
      expected to report ``True``.
    """
    cursor_value = "2024-07-01"
    locations_state = AirbyteStreamState(
        stream_descriptor=StreamDescriptor(name="locations", namespace=None),
        stream_state=AirbyteStateBlob(updated_at=cursor_value),
    )
    state = [AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=locations_state)]

    # Work on a private copy so the shared module-level manifest is never mutated.
    manifest = copy.deepcopy(_MANIFEST)
    locations_definition = manifest["definitions"]["locations_stream"]

    # Enable semi-incremental on the locations stream
    locations_definition["incremental_sync"]["is_client_side_incremental"] = True

    # Only set the flag explicitly when the case asks for it; the None case
    # exercises the behavior with the key absent from the manifest.
    if expected_transform_before_filtering is not None:
        selector_definition = locations_definition["retriever"]["record_selector"]
        selector_definition["transform_before_filtering"] = expected_transform_before_filtering

    source = ConcurrentDeclarativeSource(
        source_config=manifest,
        config=_CONFIG,
        catalog=_CATALOG,
        state=state,
    )
    concurrent_streams, _ = source._group_streams(config=_CONFIG)

    locations_stream = concurrent_streams[2]
    assert isinstance(locations_stream, DefaultStream)

    retriever = locations_stream._stream_partition_generator._partition_factory._retriever
    record_selector = retriever.record_selector

    if expected_transform_before_filtering is None:
        # Flag omitted from the manifest: the selector must come back as True.
        assert record_selector.transform_before_filtering is True
        return

    assert record_selector.transform_before_filtering == expected_transform_before_filtering


def create_wrapped_stream(stream: DeclarativeStream) -> Stream:
slice_to_records_mapping = get_mocked_read_records_output(stream_name=stream.name)

Expand Down
Loading