|
| 1 | +import sys |
1 | 2 | from typing import List |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 | from taskiq import InMemoryBroker |
5 | 6 |
|
6 | | -from taskiq_pipelines import Pipeline, PipelineMiddleware |
| 7 | +from taskiq_pipelines import Pipeline, PipelineMiddleware, AbortPipeline |
7 | 8 |
|
8 | 9 |
|
9 | 10 | @pytest.mark.anyio |
@@ -42,3 +43,32 @@ def double(i: int) -> int: |
42 | 43 | sent = await pipe.kiq(4) |
43 | 44 | res = await sent.wait_result() |
44 | 45 | assert res.return_value == list(map(double, ranger(4))) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.anyio |
| 49 | +async def test_abort_pipeline() -> None: |
| 50 | + """Test AbortPipeline.""" |
| 51 | + broker = InMemoryBroker().with_middlewares(PipelineMiddleware()) |
| 52 | + text = 'task was aborted' |
| 53 | + |
| 54 | + @broker.task |
| 55 | + def normal_task(i: bool) -> bool: |
| 56 | + return i |
| 57 | + |
| 58 | + @broker.task |
| 59 | + def aborting_task(i: int) -> bool: |
| 60 | + if i: |
| 61 | + raise AbortPipeline(text) |
| 62 | + return True |
| 63 | + |
| 64 | + pipe = Pipeline(broker, aborting_task).call_next(normal_task) |
| 65 | + sent = await pipe.kiq(0) |
| 66 | + res = await sent.wait_result() |
| 67 | + assert res.is_err is False |
| 68 | + assert res.return_value is True |
| 69 | + assert res.error is None |
| 70 | + sent = await pipe.kiq(1) |
| 71 | + res = await sent.wait_result() |
| 72 | + assert res.is_err is True |
| 73 | + assert res.return_value is None |
| 74 | + assert res.error.args[0] == text |
0 commit comments