Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,14 @@ def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str
return polling_response_context

def _get_create_job_stream_slice(self, job: AsyncJob) -> StreamSlice:
    """Build the stream slice exposed to requests issued on behalf of ``job``.

    The partition and cursor slice are taken from the job's parameters, so
    anything the original slice carried stays available. The job's extra
    fields are kept as well and extended with a ``creation_response`` entry.

    :param job: the async job whose parameters seed the returned slice
    :return: a new StreamSlice carrying the job's partition, cursor slice and
        extra fields plus the ``creation_response`` entry
    """
    job_parameters = job.job_parameters()
    return StreamSlice(
        partition=job_parameters.partition,
        cursor_slice=job_parameters.cursor_slice,
        # dict(...) gives a plain dict so the `|` merge works and yields a new
        # mapping instead of touching the job's own extra fields.
        extra_fields=dict(job_parameters.extra_fields)
        | {
            "creation_response": self._get_creation_response_interpolation_context(job),
        },
    )

def _get_download_targets(self, job: AsyncJob) -> Iterable[str]:
if not self.download_target_requester:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.types import Config, StreamSlice
from airbyte_cdk.utils.mapping_helpers import get_interpolation_context


@dataclass
Expand Down Expand Up @@ -52,8 +53,8 @@ def eval_request_inputs(
:param next_page_token: The pagination token
:return: The request inputs to set on an outgoing HTTP request
"""
kwargs = {
"stream_slice": stream_slice,
"next_page_token": next_page_token,
}
kwargs = get_interpolation_context(
stream_slice=stream_slice,
next_page_token=next_page_token,
)
return self._interpolator.eval(self.config, **kwargs) # type: ignore # self._interpolator is always initialized with a value and will not be None
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
from airbyte_cdk.utils.mapping_helpers import get_interpolation_context


@dataclass
Expand Down Expand Up @@ -51,10 +52,10 @@ def eval_request_inputs(
:param valid_value_types: A tuple of types that the interpolator should allow
:return: The request inputs to set on an outgoing HTTP request
"""
kwargs = {
"stream_slice": stream_slice,
"next_page_token": next_page_token,
}
kwargs = get_interpolation_context(
stream_slice=stream_slice,
next_page_token=next_page_token,
)
interpolated_value = self._interpolator.eval( # type: ignore # self._interpolator is always initialized with a value and will not be None
self.config,
valid_key_types=valid_key_types,
Expand Down
241 changes: 140 additions & 101 deletions unit_tests/sources/declarative/requesters/test_http_job_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


import json
from typing import Optional
from unittest import TestCase
from unittest.mock import Mock

Expand All @@ -28,6 +29,8 @@
)
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
Expand All @@ -45,111 +48,12 @@
a_record_id,a_value
"""
_A_CURSOR_FOR_PAGINATION = "a-cursor-for-pagination"
_ERROR_HANDLER = DefaultErrorHandler(config=_ANY_CONFIG, parameters={})


class HttpJobRepositoryTest(TestCase):
def setUp(self) -> None:
message_repository = Mock()
error_handler = DefaultErrorHandler(config=_ANY_CONFIG, parameters={})

self._create_job_requester = HttpRequester(
name="stream <name>: create_job",
url_base=_URL_BASE,
path=_EXPORT_PATH,
error_handler=error_handler,
http_method=HttpMethod.POST,
config=_ANY_CONFIG,
disable_retries=False,
parameters={},
message_repository=message_repository,
use_cache=False,
stream_response=False,
)

self._polling_job_requester = HttpRequester(
name="stream <name>: polling",
url_base=_URL_BASE,
path=_EXPORT_PATH + "/{{creation_response['id']}}",
error_handler=error_handler,
http_method=HttpMethod.GET,
config=_ANY_CONFIG,
disable_retries=False,
parameters={},
message_repository=message_repository,
use_cache=False,
stream_response=False,
)

self._download_retriever = SimpleRetriever(
requester=HttpRequester(
name="stream <name>: fetch_result",
url_base="",
path="{{download_target}}",
error_handler=error_handler,
http_method=HttpMethod.GET,
config=_ANY_CONFIG,
disable_retries=False,
parameters={},
message_repository=message_repository,
use_cache=False,
stream_response=True,
),
record_selector=RecordSelector(
extractor=ResponseToFileExtractor({}),
record_filter=None,
transformations=[],
schema_normalization=TypeTransformer(TransformConfig.NoTransform),
config=_ANY_CONFIG,
parameters={},
),
primary_key=None,
name="any name",
paginator=DefaultPaginator(
decoder=NoopDecoder(),
page_size_option=None,
page_token_option=RequestOption(
field_name="locator",
inject_into=RequestOptionType.request_parameter,
parameters={},
),
pagination_strategy=CursorPaginationStrategy(
cursor_value="{{ headers['Sforce-Locator'] }}",
decoder=NoopDecoder(),
config=_ANY_CONFIG,
parameters={},
),
url_base=_URL_BASE,
config=_ANY_CONFIG,
parameters={},
),
config=_ANY_CONFIG,
parameters={},
)

self._repository = AsyncHttpJobRepository(
creation_requester=self._create_job_requester,
polling_requester=self._polling_job_requester,
download_retriever=self._download_retriever,
abort_requester=None,
delete_requester=None,
status_extractor=DpathExtractor(
decoder=JsonDecoder(parameters={}),
field_path=["status"],
config={},
parameters={} or {},
),
status_mapping={
"ready": AsyncJobStatus.COMPLETED,
"failure": AsyncJobStatus.FAILED,
"pending": AsyncJobStatus.RUNNING,
},
download_target_extractor=DpathExtractor(
decoder=JsonDecoder(parameters={}),
field_path=["urls"],
config={},
parameters={} or {},
),
)
self._repository = self._create_async_job_repository()

self._http_mocker = HttpMocker()
self._http_mocker.__enter__()
Expand Down Expand Up @@ -178,6 +82,35 @@ def test_given_different_statuses_when_update_jobs_status_then_update_status_pro
self._repository.update_jobs_status([job])
assert job.status() == AsyncJobStatus.COMPLETED

def test_when_update_jobs_status_then_allow_access_to_stream_slice_information(self) -> None:
    """The polling requester can interpolate values from the job's stream slice.

    The polling path references ``stream_slice['path']``; the only mocked GET
    is registered on the URL containing the slice's ``path`` value, and the
    job is expected to end up COMPLETED after a status update.
    """
    stream_slice = StreamSlice(partition={"path": "path_from_slice"}, cursor_slice={})
    self._mock_create_response(_A_JOB_ID)

    polling_url = f"{_EXPORT_URL}/{stream_slice['path']}/{_A_JOB_ID}"
    polling_body = json.dumps({"id": _A_JOB_ID, "status": "ready"})
    self._http_mocker.get(HttpRequest(url=polling_url), HttpResponse(body=polling_body))

    # The message repository below might not align with the rest of the components in
    # the async job repository; if message_repository becomes important for tests,
    # please share this instance with the other components.
    polling_requester = HttpRequester(
        name="stream <name>: polling",
        url_base=_URL_BASE,
        path=_EXPORT_PATH + "/{{stream_slice['path']}}/{{creation_response['id']}}",
        error_handler=_ERROR_HANDLER,
        http_method=HttpMethod.GET,
        config=_ANY_CONFIG,
        disable_retries=False,
        parameters={},
        message_repository=Mock(),
        use_cache=False,
        stream_response=False,
    )
    repository = self._create_async_job_repository(polling_requester)

    job = repository.start(stream_slice)
    repository.update_jobs_status([job])

    assert job.status() == AsyncJobStatus.COMPLETED

def test_given_unknown_status_when_update_jobs_status_then_raise_error(self) -> None:
self._mock_create_response(_A_JOB_ID)
self._http_mocker.get(
Expand Down Expand Up @@ -277,3 +210,109 @@ def _mock_create_response(self, job_id: str) -> None:
HttpRequest(url=_EXPORT_URL),
HttpResponse(body=json.dumps({"id": job_id})),
)

def _create_async_job_repository(
    self, polling_job_requester: Optional[HttpRequester] = None
) -> AsyncHttpJobRepository:
    """Assemble the AsyncHttpJobRepository used by these tests.

    :param polling_job_requester: requester to use for polling; when omitted
        (or falsy) a default one polling ``_EXPORT_PATH/<creation_response id>``
        is built
    :return: a repository wired with a creation requester, the polling
        requester, a download retriever, and status / download-target
        extractors; no abort or delete requester is configured
    """
    shared_message_repository = Mock()

    def build_requester(
        name: str, url_base: str, path: str, http_method: HttpMethod, stream_response: bool
    ) -> HttpRequester:
        # The requesters built here differ only in these five settings.
        return HttpRequester(
            name=name,
            url_base=url_base,
            path=path,
            error_handler=_ERROR_HANDLER,
            http_method=http_method,
            config=_ANY_CONFIG,
            disable_retries=False,
            parameters={},
            message_repository=shared_message_repository,
            use_cache=False,
            stream_response=stream_response,
        )

    def build_extractor(field: str) -> DpathExtractor:
        # Extracts a single top-level field from a JSON-decoded response.
        return DpathExtractor(
            decoder=JsonDecoder(parameters={}),
            field_path=[field],
            config={},
            parameters={},
        )

    create_job_requester = build_requester(
        "stream <name>: create_job", _URL_BASE, _EXPORT_PATH, HttpMethod.POST, False
    )

    # Only build the default polling requester when the caller did not supply one.
    if not polling_job_requester:
        polling_job_requester = build_requester(
            "stream <name>: polling",
            _URL_BASE,
            _EXPORT_PATH + "/{{creation_response['id']}}",
            HttpMethod.GET,
            False,
        )

    download_requester = build_requester(
        "stream <name>: fetch_result", "", "{{download_target}}", HttpMethod.GET, True
    )
    record_selector = RecordSelector(
        extractor=ResponseToFileExtractor({}),
        record_filter=None,
        transformations=[],
        schema_normalization=TypeTransformer(TransformConfig.NoTransform),
        config=_ANY_CONFIG,
        parameters={},
    )
    paginator = DefaultPaginator(
        decoder=NoopDecoder(),
        page_size_option=None,
        page_token_option=RequestOption(
            field_name="locator",
            inject_into=RequestOptionType.request_parameter,
            parameters={},
        ),
        pagination_strategy=CursorPaginationStrategy(
            cursor_value="{{ headers['Sforce-Locator'] }}",
            decoder=NoopDecoder(),
            config=_ANY_CONFIG,
            parameters={},
        ),
        url_base=_URL_BASE,
        config=_ANY_CONFIG,
        parameters={},
    )
    download_retriever = SimpleRetriever(
        requester=download_requester,
        record_selector=record_selector,
        primary_key=None,
        name="any name",
        paginator=paginator,
        config=_ANY_CONFIG,
        parameters={},
    )

    return AsyncHttpJobRepository(
        creation_requester=create_job_requester,
        polling_requester=polling_job_requester,
        download_retriever=download_retriever,
        abort_requester=None,
        delete_requester=None,
        status_extractor=build_extractor("status"),
        status_mapping={
            "ready": AsyncJobStatus.COMPLETED,
            "failure": AsyncJobStatus.FAILED,
            "pending": AsyncJobStatus.RUNNING,
        },
        download_target_extractor=build_extractor("urls"),
    )
Loading