diff --git a/taskiq_pipelines/abc.py b/taskiq_pipelines/abc.py
index 3c43dd0..518b4b1 100644
--- a/taskiq_pipelines/abc.py
+++ b/taskiq_pipelines/abc.py
@@ -43,6 +43,7 @@ async def act(
         self,
         broker: AsyncBroker,
         step_number: int,
+        parent_task_id: str,
         task_id: str,
         pipe_data: str,
         result: "TaskiqResult[Any]",
@@ -57,6 +58,7 @@ async def act(
 
         :param broker: current broker.
         :param step_number: current step number.
+        :param parent_task_id: task_id of the previous step.
         :param task_id: task_id to use.
         :param pipe_data: serialized pipeline must be in labels.
         :param result: result of a previous task.
diff --git a/taskiq_pipelines/middleware.py b/taskiq_pipelines/middleware.py
index 2932385..fe687f0 100644
--- a/taskiq_pipelines/middleware.py
+++ b/taskiq_pipelines/middleware.py
@@ -63,6 +63,7 @@ async def post_execute(  # noqa: C901, WPS212
         await next_step.act(
             broker=self.broker,
             step_number=current_step_num + 1,
+            parent_task_id=message.task_id,
             task_id=next_step_data.task_id,
             pipe_data=pipeline_data,
             result=result,
diff --git a/taskiq_pipelines/pipeliner.py b/taskiq_pipelines/pipeliner.py
index 3fe416a..58dee2f 100644
--- a/taskiq_pipelines/pipeliner.py
+++ b/taskiq_pipelines/pipeliner.py
@@ -8,7 +8,7 @@
 from typing_extensions import ParamSpec
 
 from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
-from taskiq_pipelines.steps import MapperStep, SequentialStep, parse_step
+from taskiq_pipelines.steps import FilterStep, MapperStep, SequentialStep, parse_step
 
 _ReturnType = TypeVar("_ReturnType")
 _FuncParams = ParamSpec("_FuncParams")
@@ -182,6 +182,78 @@ def map(
         )
         return self
 
+    @overload
+    def filter(
+        self: "Pipeline[_FuncParams, _ReturnType]",
+        task: Union[
+            AsyncKicker[Any, Coroutine[Any, Any, bool]],
+            AsyncTaskiqDecoratedTask[Any, Coroutine[Any, Any, bool]],
+        ],
+        param_name: Optional[str] = None,
+        skip_errors: bool = False,
+        check_interval: float = 0.5,
+        **additional_kwargs: Any,
+    ) -> "Pipeline[_FuncParams, _ReturnType]":
+        ...
+
+    @overload
+    def filter(
+        self: "Pipeline[_FuncParams, _ReturnType]",
+        task: Union[
+            AsyncKicker[Any, bool],
+            AsyncTaskiqDecoratedTask[Any, bool],
+        ],
+        param_name: Optional[str] = None,
+        skip_errors: bool = False,
+        check_interval: float = 0.5,
+        **additional_kwargs: Any,
+    ) -> "Pipeline[_FuncParams, _ReturnType]":
+        ...
+
+    def filter(
+        self,
+        task: Union[
+            AsyncKicker[Any, Any],
+            AsyncTaskiqDecoratedTask[Any, Any],
+        ],
+        param_name: Optional[str] = None,
+        skip_errors: bool = False,
+        check_interval: float = 0.5,
+        **additional_kwargs: Any,
+    ) -> Any:
+        """
+        Add filter step.
+
+        This step is executed on a list of items,
+        like the map step.
+
+        It runs a small subtask for every item in
+        the sequence, and if the subtask returns True,
+        the item is added to the final list.
+
+        :param task: task to execute on every item.
+        :param param_name: parameter name to pass item into, defaults to None
+        :param skip_errors: skip errors if any, defaults to False
+        :param check_interval: how often the result of all subtasks is checked,
+            defaults to 0.5
+        :param additional_kwargs: additional function's kwargs.
+        :return: pipeline with filtering step.
+        """
+        self.steps.append(
+            DumpedStep(
+                step_type=FilterStep.step_name,
+                step_data=FilterStep.from_task(
+                    task=task,
+                    param_name=param_name,
+                    skip_errors=skip_errors,
+                    check_interval=check_interval,
+                    **additional_kwargs,
+                ).dumps(),
+                task_id="",
+            ),
+        )
+        return self
+
     def dumps(self) -> str:
         """
         Dumps current pipeline as string.
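Before the remaining files, a quick usage sketch of the new step. This is illustrative only and not part of the patch: it assumes taskiq's InMemoryBroker and the package's Pipeline and PipelineMiddleware exports, and the task names and bodies (generate_numbers, is_even) are hypothetical.

import asyncio
from typing import List

from taskiq import InMemoryBroker

from taskiq_pipelines import Pipeline, PipelineMiddleware

# The pipeline middleware must be installed on the broker,
# since it is what advances the pipeline between steps.
broker = InMemoryBroker().with_middlewares(PipelineMiddleware())


@broker.task
async def generate_numbers(count: int) -> List[int]:
    # First step: produce the iterable that filter() fans out over.
    return list(range(count))


@broker.task
async def is_even(number: int) -> bool:
    # Per-item subtask: a truthy return value keeps the item.
    return number % 2 == 0


async def main() -> None:
    pipe = Pipeline(broker, generate_numbers).filter(is_even, check_interval=0.1)
    task = await pipe.kiq(10)
    result = await task.wait_result()
    print(result.return_value)  # expected: [0, 2, 4, 6, 8]


asyncio.run(main())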
diff --git a/taskiq_pipelines/steps/__init__.py b/taskiq_pipelines/steps/__init__.py
index 409f9e5..b17f08e 100644
--- a/taskiq_pipelines/steps/__init__.py
+++ b/taskiq_pipelines/steps/__init__.py
@@ -2,6 +2,7 @@
 from logging import getLogger
 
 from taskiq_pipelines.abc import AbstractStep
+from taskiq_pipelines.steps.filter import FilterStep
 from taskiq_pipelines.steps.mapper import MapperStep
 from taskiq_pipelines.steps.sequential import SequentialStep
 
@@ -19,4 +20,5 @@ def parse_step(step_type: str, step_data: str) -> AbstractStep:
 __all__ = [
     "MapperStep",
     "SequentialStep",
+    "FilterStep",
 ]
diff --git a/taskiq_pipelines/steps/filter.py b/taskiq_pipelines/steps/filter.py
new file mode 100644
index 0000000..c698d59
--- /dev/null
+++ b/taskiq_pipelines/steps/filter.py
@@ -0,0 +1,185 @@
+import asyncio
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import pydantic
+from taskiq import AsyncBroker, TaskiqError, TaskiqResult
+from taskiq.brokers.shared_broker import async_shared_broker
+from taskiq.context import Context, default_context
+from taskiq.decor import AsyncTaskiqDecoratedTask
+from taskiq.kicker import AsyncKicker
+
+from taskiq_pipelines.abc import AbstractStep
+from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
+from taskiq_pipelines.exceptions import AbortPipeline
+
+
+@async_shared_broker.task(task_name="taskiq_pipelines.shared.filter_tasks")
+async def filter_tasks(  # noqa: C901, WPS210, WPS231
+    task_ids: List[str],
+    parent_task_id: str,
+    check_interval: float,
+    context: Context = default_context,
+    skip_errors: bool = False,
+) -> List[Any]:
+    """
+    Filter resulted tasks.
+
+    It takes a list of task ids
+    and the parent task id.
+
+    After all subtasks have completed, it fetches
+    the result of the parent task, and for every
+    subtask whose result evaluates to True, the
+    corresponding item from the parent task's result
+    is added to the resulting list.
+
+    :param task_ids: ordered list of task ids.
+    :param parent_task_id: task id of a parent task.
+    :param check_interval: how often checks are performed.
+    :param context: context of the execution, defaults to default_context
+    :param skip_errors: skip errors of subtasks, defaults to False
+    :raises TaskiqError: if any subtask has returned an error.
+    :return: filtered results.
+    """
+    ordered_ids = task_ids[:]
+    tasks_set = set(task_ids)
+    while tasks_set:
+        for task_id in task_ids:  # noqa: WPS327
+            if await context.broker.result_backend.is_result_ready(task_id):
+                try:
+                    tasks_set.remove(task_id)
+                except LookupError:
+                    continue
+        await asyncio.sleep(check_interval)
+
+    results = await context.broker.result_backend.get_result(parent_task_id)
+    filtered_results = []
+    for task_id, value in zip(  # type: ignore  # noqa: WPS352, WPS440
+        ordered_ids,
+        results.return_value,
+    ):
+        result = await context.broker.result_backend.get_result(task_id)
+        if result.is_err:
+            if skip_errors:
+                continue
+            raise TaskiqError(f"Task {task_id} returned error. Filtering failed.")
+        if result.return_value:
+            filtered_results.append(value)
+    return filtered_results
+
+
+class FilterStep(pydantic.BaseModel, AbstractStep, step_name="filter"):
+    """Step to filter results."""
+
+    task_name: str
+    labels: Dict[str, str]
+    param_name: Optional[str]
+    additional_kwargs: Dict[str, Any]
+    skip_errors: bool
+    check_interval: float
+
+    def dumps(self) -> str:
+        """
+        Dumps step as string.
+
+        :return: returns json.
+        """
+        return self.json()
+
+    @classmethod
+    def loads(cls, data: str) -> "FilterStep":
+        """
+        Parses filter step from string.
+
+        :param data: dumped data.
+        :return: parsed step.
+        """
+        return pydantic.parse_raw_as(FilterStep, data)
+
+    async def act(
+        self,
+        broker: AsyncBroker,
+        step_number: int,
+        parent_task_id: str,
+        task_id: str,
+        pipe_data: str,
+        result: "TaskiqResult[Any]",
+    ) -> None:
+        """
+        Run filter action.
+
+        This function kicks a subtask for every item and
+        then collects all results into one filtered list,
+        using the 'filter_tasks' shared task.
+
+        :param broker: current broker.
+        :param step_number: current step number.
+        :param parent_task_id: task_id of the previous step.
+        :param task_id: task_id to use in this step.
+        :param pipe_data: serialized pipeline.
+        :param result: result of the previous task.
+        :raises AbortPipeline: if result is not iterable.
+        """
+        if not isinstance(result.return_value, Iterable):
+            raise AbortPipeline("Result of the previous task is not iterable.")
+        sub_task_ids = []
+        for item in result.return_value:
+            kicker: "AsyncKicker[Any, Any]" = AsyncKicker(
+                task_name=self.task_name,
+                broker=broker,
+                labels=self.labels,
+            )
+            if self.param_name:
+                self.additional_kwargs[self.param_name] = item
+                task = await kicker.kiq(**self.additional_kwargs)
+            else:
+                task = await kicker.kiq(item, **self.additional_kwargs)
+            sub_task_ids.append(task.task_id)
+
+        await filter_tasks.kicker().with_task_id(task_id).with_broker(
+            broker,
+        ).with_labels(
+            **{CURRENT_STEP: step_number, PIPELINE_DATA: pipe_data},  # type: ignore
+        ).kiq(
+            sub_task_ids,
+            parent_task_id,
+            check_interval=self.check_interval,
+            skip_errors=self.skip_errors,
+        )
+
+    @classmethod
+    def from_task(
+        cls,
+        task: Union[
+            AsyncKicker[Any, Any],
+            AsyncTaskiqDecoratedTask[Any, Any],
+        ],
+        param_name: Optional[str],
+        skip_errors: bool,
+        check_interval: float,
+        **additional_kwargs: Any,
+    ) -> "FilterStep":
+        """
+        Create a new filter step from a task.
+
+        :param task: task to execute.
+        :param param_name: parameter name.
+        :param skip_errors: don't fail the collector
+            task on errors.
+        :param check_interval: how often tasks are checked.
+        :param additional_kwargs: additional function's kwargs.
+        :return: new filter step.
+ """ + if isinstance(task, AsyncTaskiqDecoratedTask): + kicker = task.kicker() + else: + kicker = task + message = kicker._prepare_message() # noqa: WPS437 + return FilterStep( + task_name=message.task_name, + labels=message.labels, + param_name=param_name, + additional_kwargs=additional_kwargs, + skip_errors=skip_errors, + check_interval=check_interval, + ) diff --git a/taskiq_pipelines/steps/mapper.py b/taskiq_pipelines/steps/mapper.py index d2a5083..802fdc2 100644 --- a/taskiq_pipelines/steps/mapper.py +++ b/taskiq_pipelines/steps/mapper.py @@ -17,8 +17,8 @@ from taskiq_pipelines.exceptions import AbortPipeline -@async_shared_broker.task(task_name="taskiq_pipelines.wait_tasks") -async def wait_tasks( # noqa: C901 +@async_shared_broker.task(task_name="taskiq_pipelines.shared.wait_tasks") +async def wait_tasks( # noqa: C901, WPS231 task_ids: List[str], check_interval: float, context: Context = default_context, @@ -44,9 +44,12 @@ async def wait_tasks( # noqa: C901 ordered_ids = task_ids[:] tasks_set = set(task_ids) while tasks_set: - for task_id in task_ids: + for task_id in task_ids: # noqa: WPS327 if await context.broker.result_backend.is_result_ready(task_id): - tasks_set.remove(task_id) + try: + tasks_set.remove(task_id) + except LookupError: + continue await asyncio.sleep(check_interval) results = [] @@ -92,6 +95,7 @@ async def act( self, broker: AsyncBroker, step_number: int, + parent_task_id: str, task_id: str, pipe_data: str, result: "TaskiqResult[Any]", @@ -109,6 +113,7 @@ async def act( :param broker: current broker. :param step_number: current step number. :param task_id: waiter task_id. + :param parent_task_id: task_id of the previous step. :param pipe_data: serialized pipeline. :param result: result of the previous task. :raises AbortPipeline: if the result of the diff --git a/taskiq_pipelines/steps/sequential.py b/taskiq_pipelines/steps/sequential.py index f715f58..e81b82b 100644 --- a/taskiq_pipelines/steps/sequential.py +++ b/taskiq_pipelines/steps/sequential.py @@ -44,6 +44,7 @@ async def act( self, broker: AsyncBroker, step_number: int, + parent_task_id: str, task_id: str, pipe_data: str, result: "TaskiqResult[Any]", @@ -61,6 +62,7 @@ async def act( :param broker: current broker. :param step_number: current step number. + :param parent_task_id: current step's task id. :param task_id: new task id. :param pipe_data: serialized pipeline. :param result: result of the previous task.