feat: use create_concurrent_cursor_from_perpartition_cursor (#286)

Signed-off-by: Artem Inzhyyants <[email protected]>
artem1205 authored Jan 30, 2025
1 parent dea2cc9 commit ee537af

Showing 8 changed files with 55 additions and 38 deletions.
8 changes: 4 additions & 4 deletions airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
@@ -482,16 +482,16 @@ def _is_breaking_exception(self, exception: Exception) -> bool:
             and exception.failure_type == FailureType.config_error
         )

-    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
+    def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]:
         """
-        Fetches records from the given partition's jobs.
+        Fetches records from the given jobs.

         Args:
-            partition (AsyncPartition): The partition containing the jobs.
+            async_jobs Iterable[AsyncJob]: The list of AsyncJobs.

         Yields:
             Iterable[Mapping[str, Any]]: The fetched records from the jobs.
         """
-        for job in partition.jobs:
+        for job in async_jobs:
             yield from self._job_repository.fetch_records(job)
             self._job_repository.delete(job)
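Note on the new contract: fetch_records now consumes AsyncJob objects directly instead of unpacking an AsyncPartition. A minimal sketch of a call site under the new signature, assuming an orchestrator already wired to a job repository (the read_all helper is illustrative, not part of the CDK):

    from typing import Any, Iterable, Mapping

    from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
    from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator

    def read_all(
        orchestrator: AsyncJobOrchestrator, jobs: Iterable[AsyncJob]
    ) -> Iterable[Mapping[str, Any]]:
        # Each job's records are yielded, then the job is deleted from the
        # repository, so the resulting iterable can only be consumed once.
        yield from orchestrator.fetch_records(async_jobs=jobs)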
@@ -19,6 +19,7 @@
 from airbyte_cdk.sources.declarative.extractors.record_filter import (
     ClientSideIncrementalRecordFilterDecorator,
 )
+from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
 from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
 from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import (
     PerPartitionWithGlobalCursor,
@@ -231,7 +232,7 @@ def _group_streams(
             ):
                 cursor = declarative_stream.retriever.stream_slicer.stream_slicer

-                if not isinstance(cursor, ConcurrentCursor):
+                if not isinstance(cursor, ConcurrentCursor | ConcurrentPerPartitionCursor):
                     # This should never happen since we instantiate ConcurrentCursor in
                     # model_to_component_factory.py
                     raise ValueError(
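The widened check relies on PEP 604 union syntax (`X | Y`) as an isinstance target, which requires Python 3.10 or newer; on older interpreters the equivalent spelling is a tuple of types. A quick standalone illustration (stub classes stand in for the real cursors):

    class ConcurrentCursor: ...
    class ConcurrentPerPartitionCursor: ...

    cursor = ConcurrentPerPartitionCursor()

    # Python 3.10+: a union type works directly as the isinstance target
    assert isinstance(cursor, ConcurrentCursor | ConcurrentPerPartitionCursor)

    # Equivalent tuple form that also works on older versions
    assert isinstance(cursor, (ConcurrentCursor, ConcurrentPerPartitionCursor))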
4 changes: 3 additions & 1 deletion airbyte_cdk/sources/declarative/declarative_stream.py
@@ -138,7 +138,9 @@ def read_records(
         """
         :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
         """
-        if stream_slice is None or stream_slice == {}:
+        if stream_slice is None or (
+            not isinstance(stream_slice, StreamSlice) and stream_slice == {}
+        ):
             # As the parameter is Optional, many would just call `read_records(sync_mode)` during testing without specifying the field
             # As part of the declarative model without custom components, this should never happen as the CDK would wire up a
             # SinglePartitionRouter that would create this StreamSlice properly
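The extra isinstance guard matters because an empty StreamSlice compares equal to {} under its Mapping semantics, so the old check would have short-circuited a legitimate slice, including the slices that now carry jobs in extra_fields. A sketch of the distinction, assuming StreamSlice's dict-like equality:

    from airbyte_cdk.sources.types import StreamSlice

    def slice_is_missing(stream_slice) -> bool:
        # Mirrors the updated guard: only None or a bare empty mapping
        # (e.g. {}) counts as missing, never a real StreamSlice.
        return stream_slice is None or (
            not isinstance(stream_slice, StreamSlice) and stream_slice == {}
        )

    empty_slice = StreamSlice(partition={}, cursor_slice={})
    assert empty_slice == {}                  # dict-like equality holds...
    assert not slice_is_missing(empty_slice)  # ...but the slice is still valid
    assert slice_is_missing({})
    assert slice_is_missing(None)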
@@ -1656,7 +1656,7 @@ def _build_stream_slicer_from_partition_router(
     ) -> Optional[PartitionRouter]:
         if (
             hasattr(model, "partition_router")
-            and isinstance(model, SimpleRetrieverModel)
+            and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel)
             and model.partition_router
         ):
             stream_slicer_model = model.partition_router
@@ -1690,6 +1690,31 @@ def _merge_stream_slicers(
         stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config)

         if model.incremental_sync and stream_slicer:
+            if model.retriever.type == "AsyncRetriever":
+                if model.incremental_sync.type != "DatetimeBasedCursor":
+                    # We are currently in a transition to the Concurrent CDK and AsyncRetriever can only work with the support of unordered slices (for example, when we trigger reports for January and February, the report in February can be completed first). Once we have support for custom concurrent cursor or have a new implementation available in the CDK, we can enable more cursors here.
+                    raise ValueError(
+                        "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet"
+                    )
+                if stream_slicer:
+                    return self.create_concurrent_cursor_from_perpartition_cursor(  # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing
+                        state_manager=self._connector_state_manager,
+                        model_type=DatetimeBasedCursorModel,
+                        component_definition=model.incremental_sync.__dict__,
+                        stream_name=model.name or "",
+                        stream_namespace=None,
+                        config=config or {},
+                        stream_state={},
+                        partition_router=stream_slicer,
+                    )
+                return self.create_concurrent_cursor_from_datetime_based_cursor(  # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing
+                    model_type=DatetimeBasedCursorModel,
+                    component_definition=model.incremental_sync.__dict__,
+                    stream_name=model.name or "",
+                    stream_namespace=None,
+                    config=config or {},
+                )

         incremental_sync_model = model.incremental_sync
         if (
             hasattr(incremental_sync_model, "global_substream_cursor")
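In effect, the factory now routes AsyncRetriever streams to one of two concurrent cursors: the per-partition variant when a partition router is present, otherwise a plain datetime-based concurrent cursor, and any incremental sync other than DatetimeBasedCursor is rejected. A simplified mirror of that decision flow (a standalone sketch, not the factory's actual signature):

    def choose_cursor(incremental_type: str, has_partition_router: bool) -> str:
        # AsyncRetriever only supports datetime-based incremental syncs for now.
        if incremental_type != "DatetimeBasedCursor":
            raise ValueError(
                "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet"
            )
        # A partition router upgrades the cursor to its per-partition variant.
        if has_partition_router:
            return "create_concurrent_cursor_from_perpartition_cursor"
        return "create_concurrent_cursor_from_datetime_based_cursor"

    assert choose_cursor("DatetimeBasedCursor", True).endswith("perpartition_cursor")
    assert choose_cursor("DatetimeBasedCursor", False).endswith("datetime_based_cursor")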
@@ -4,9 +4,9 @@
 from typing import Any, Callable, Iterable, Mapping, Optional

 from airbyte_cdk.models import FailureType
+from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
     AsyncJobOrchestrator,
-    AsyncPartition,
 )
 from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
     SinglePartitionRouter,
@@ -42,12 +42,12 @@ def stream_slices(self) -> Iterable[StreamSlice]:

         for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
             yield StreamSlice(
-                partition=dict(completed_partition.stream_slice.partition)
-                | {"partition": completed_partition},
+                partition=dict(completed_partition.stream_slice.partition),
                 cursor_slice=completed_partition.stream_slice.cursor_slice,
+                extra_fields={"jobs": list(completed_partition.jobs)},
             )

-    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
+    def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]:
         """
         This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should
         be responsible for. However, this was added in because the JobOrchestrator is required to
@@ -62,4 +62,4 @@ def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]
                 failure_type=FailureType.system_error,
             )

-        return self._job_orchestrator.fetch_records(partition=partition)
+        return self._job_orchestrator.fetch_records(async_jobs=async_jobs)
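Design note: the completed jobs now ride on StreamSlice.extra_fields instead of being tucked into the partition mapping, which keeps the partition itself plain, serializable data for cursor and state handling while still handing the jobs to the retriever. A small round-trip sketch (the strings are stand-ins for real AsyncJob instances):

    from airbyte_cdk.sources.types import StreamSlice

    jobs = ["job_1", "job_2"]  # stand-ins for AsyncJob objects

    slice_ = StreamSlice(
        partition={"parent_id": "0"},
        cursor_slice={},
        extra_fields={"jobs": jobs},
    )

    # The partition stays a plain mapping...
    assert dict(slice_.partition) == {"parent_id": "0"}
    # ...while downstream consumers recover the jobs from extra_fields.
    assert list(slice_.extra_fields["jobs"]) == jobs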
18 changes: 6 additions & 12 deletions airbyte_cdk/sources/declarative/retrievers/async_retriever.py
@@ -6,7 +6,7 @@

 from typing_extensions import deprecated

-from airbyte_cdk.models import FailureType
+from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
 from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
@@ -16,7 +16,6 @@
 from airbyte_cdk.sources.source import ExperimentalClassWarning
 from airbyte_cdk.sources.streams.core import StreamData
 from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
-from airbyte_cdk.utils.traced_exception import AirbyteTracedException


 @deprecated(
@@ -57,9 +56,9 @@ def _get_stream_state(self) -> StreamState:

         return self.state

-    def _validate_and_get_stream_slice_partition(
+    def _validate_and_get_stream_slice_jobs(
         self, stream_slice: Optional[StreamSlice] = None
-    ) -> AsyncPartition:
+    ) -> Iterable[AsyncJob]:
         """
         Validates the stream_slice argument and returns the partition from it.
@@ -73,12 +72,7 @@ def _validate_and_get_stream_slice_partition(
             AirbyteTracedException: If the stream_slice is not an instance of StreamSlice or if the partition is not present in the stream_slice.
         """
-        if not isinstance(stream_slice, StreamSlice) or "partition" not in stream_slice.partition:
-            raise AirbyteTracedException(
-                message="Invalid arguments to AsyncRetriever.read_records: stream_slice is not optional. Please contact Airbyte Support",
-                failure_type=FailureType.system_error,
-            )
-        return stream_slice["partition"]  # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices
+        return stream_slice.extra_fields.get("jobs", []) if stream_slice else []

     def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
         return self.stream_slicer.stream_slices()
@@ -89,8 +83,8 @@ def read_records(
         stream_slice: Optional[StreamSlice] = None,
     ) -> Iterable[StreamData]:
         stream_state: StreamState = self._get_stream_state()
-        partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
-        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
+        jobs: Iterable[AsyncJob] = self._validate_and_get_stream_slice_jobs(stream_slice)
+        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(jobs)

         yield from self.record_selector.filter_and_transform(
             all_data=records,
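With the validation removed, a slice without jobs now degrades to an empty read instead of raising AirbyteTracedException. A condensed sketch of the new extraction logic, assuming extra_fields defaults to an empty mapping when not supplied:

    from airbyte_cdk.sources.types import StreamSlice

    def get_jobs(stream_slice):
        # Mirrors _validate_and_get_stream_slice_jobs: jobs come from
        # extra_fields; None or a slice without jobs yields an empty list.
        return stream_slice.extra_fields.get("jobs", []) if stream_slice else []

    assert get_jobs(None) == []
    assert get_jobs(StreamSlice(partition={}, cursor_slice={})) == []
    with_jobs = StreamSlice(
        partition={}, cursor_slice={}, extra_fields={"jobs": ["j1"]}
    )
    assert get_jobs(with_jobs) == ["j1"]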
@@ -174,9 +174,8 @@ def test_when_fetch_records_then_yield_records_from_each_job(self) -> None:
         orchestrator = self._orchestrator([_A_STREAM_SLICE])
         first_job = _create_job()
         second_job = _create_job()
-        partition = AsyncPartition([first_job, second_job], _A_STREAM_SLICE)

-        records = list(orchestrator.fetch_records(partition))
+        records = list(orchestrator.fetch_records([first_job, second_job]))

         assert len(records) == 2
         assert self._job_repository.fetch_records.mock_calls == [call(first_job), call(second_job)]
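The same assertion pattern can be reproduced standalone with unittest.mock, which may help when porting connector tests to the new signature (this is an illustrative re-implementation, not the CDK's test harness):

    from unittest.mock import Mock, call

    job_repository = Mock()
    job_repository.fetch_records.side_effect = [[{"id": 1}], [{"id": 2}]]

    def fetch_records(async_jobs):
        # Same shape as AsyncJobOrchestrator.fetch_records: yield, then delete.
        for job in async_jobs:
            yield from job_repository.fetch_records(job)
            job_repository.delete(job)

    first_job, second_job = Mock(), Mock()
    records = list(fetch_records([first_job, second_job]))

    assert len(records) == 2
    assert job_repository.fetch_records.mock_calls == [call(first_job), call(second_job)]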
@@ -35,12 +35,12 @@ def test_stream_slices_with_single_partition_router():

     slices = list(partition_router.stream_slices())
     assert len(slices) == 1
-    partition = slices[0].partition.get("partition")
-    assert isinstance(partition, AsyncPartition)
-    assert partition.stream_slice == StreamSlice(partition={}, cursor_slice={})
-    assert partition.status == AsyncJobStatus.COMPLETED
+    partition = slices[0]
+    assert isinstance(partition, StreamSlice)
+    assert partition == StreamSlice(partition={}, cursor_slice={})
+    assert partition.extra_fields["jobs"][0].status() == AsyncJobStatus.COMPLETED

-    attempts_per_job = list(partition.jobs)
+    attempts_per_job = list(partition.extra_fields["jobs"])
     assert len(attempts_per_job) == 1
     assert attempts_per_job[0].api_job_id() == "a_job_id"
     assert attempts_per_job[0].job_parameters() == StreamSlice(partition={}, cursor_slice={})
@@ -68,14 +68,10 @@ def test_stream_slices_with_parent_slicer():
     slices = list(partition_router.stream_slices())
     assert len(slices) == 3
     for i, partition in enumerate(slices):
-        partition = partition.partition.get("partition")
-        assert isinstance(partition, AsyncPartition)
-        assert partition.stream_slice == StreamSlice(
-            partition={"parent_id": str(i)}, cursor_slice={}
-        )
-        assert partition.status == AsyncJobStatus.COMPLETED
+        assert isinstance(partition, StreamSlice)
+        assert partition == StreamSlice(partition={"parent_id": str(i)}, cursor_slice={})

-        attempts_per_job = list(partition.jobs)
+        attempts_per_job = list(partition.extra_fields["jobs"])
         assert len(attempts_per_job) == 1
         assert attempts_per_job[0].api_job_id() == "a_job_id"
         assert attempts_per_job[0].job_parameters() == StreamSlice(
