diff --git a/src/integrations/prefect-dask/prefect_dask/task_runners.py b/src/integrations/prefect-dask/prefect_dask/task_runners.py index 3b7e8608bf9a..5f565b88fd69 100644 --- a/src/integrations/prefect-dask/prefect_dask/task_runners.py +++ b/src/integrations/prefect-dask/prefect_dask/task_runners.py @@ -79,7 +79,6 @@ def count_to(highest_number): Coroutine, Dict, Iterable, - List, Optional, Set, TypeVar, @@ -91,7 +90,7 @@ def count_to(highest_number): from typing_extensions import ParamSpec from prefect.client.schemas.objects import State, TaskRunInput -from prefect.futures import PrefectFuture, PrefectWrappedFuture +from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture from prefect.logging.loggers import get_logger, get_run_logger from prefect.task_runners import TaskRunner from prefect.tasks import Task @@ -366,7 +365,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectDaskFuture[R]]: + ) -> PrefectFutureList[PrefectDaskFuture[R]]: ... @overload @@ -375,7 +374,7 @@ def map( task: "Task[Any, R]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectDaskFuture[R]]: + ) -> PrefectFutureList[PrefectDaskFuture[R]]: ... def map( diff --git a/src/integrations/prefect-ray/prefect_ray/task_runners.py b/src/integrations/prefect-ray/prefect_ray/task_runners.py index ff5b11961d9a..67f9dd1e1f59 100644 --- a/src/integrations/prefect-ray/prefect_ray/task_runners.py +++ b/src/integrations/prefect-ray/prefect_ray/task_runners.py @@ -78,7 +78,6 @@ def count_to(highest_number): Coroutine, Dict, Iterable, - List, Optional, Set, TypeVar, @@ -92,7 +91,7 @@ def count_to(highest_number): from prefect.client.schemas.objects import TaskRunInput from prefect.context import serialize_context -from prefect.futures import PrefectFuture, PrefectWrappedFuture +from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture from prefect.logging.loggers import get_logger, get_run_logger from prefect.states import State, exception_to_crashed_state from prefect.task_engine import run_task_async, run_task_sync @@ -291,7 +290,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectRayFuture[R]]: + ) -> PrefectFutureList[PrefectRayFuture[R]]: ... @overload @@ -300,7 +299,7 @@ def map( task: "Task[Any, R]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectRayFuture[R]]: + ) -> PrefectFutureList[PrefectRayFuture[R]]: ... def map( diff --git a/src/prefect/futures.py b/src/prefect/futures.py index 8ce20b416bdd..035572a50701 100644 --- a/src/prefect/futures.py +++ b/src/prefect/futures.py @@ -2,8 +2,9 @@ import concurrent.futures import inspect import uuid +from collections.abc import Iterator from functools import partial -from typing import Any, Generic, Optional, Set, Union, cast +from typing import Any, Generic, List, Optional, Set, Union, cast from typing_extensions import TypeVar @@ -16,6 +17,7 @@ from prefect.utilities.annotations import quote from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.collections import StopVisiting, visit_collection +from prefect.utilities.timeout import timeout as timeout_context F = TypeVar("F") R = TypeVar("R") @@ -62,7 +64,7 @@ def wait(self, timeout: Optional[float] = None) -> None: If the task run has already completed, this method will return immediately. Args: - - timeout: The maximum number of seconds to wait for the task run to complete. + timeout: The maximum number of seconds to wait for the task run to complete. If the task run has not completed after the timeout has elapsed, this method will return. """ @@ -79,9 +81,9 @@ def result( If the task run has not completed, this method will wait for the task run to complete. Args: - - timeout: The maximum number of seconds to wait for the task run to complete. + timeout: The maximum number of seconds to wait for the task run to complete. If the task run has not completed after the timeout has elapsed, this method will return. - - raise_on_failure: If `True`, an exception will be raised if the task run fails. + raise_on_failure: If `True`, an exception will be raised if the task run fails. Returns: The result of the task run. @@ -233,6 +235,63 @@ def __eq__(self, other): return self.task_run_id == other.task_run_id +class PrefectFutureList(list, Iterator, Generic[F]): + """ + A list of Prefect futures. + + This class provides methods to wait for all futures + in the list to complete and to retrieve the results of all task runs. + """ + + def wait(self, timeout: Optional[float] = None) -> None: + """ + Wait for all futures in the list to complete. + + Args: + timeout: The maximum number of seconds to wait for all futures to + complete. This method will not raise if the timeout is reached. + """ + try: + with timeout_context(timeout): + for future in self: + future.wait() + except TimeoutError: + logger.debug("Timed out waiting for all futures to complete.") + return + + def result( + self, + timeout: Optional[float] = None, + raise_on_failure: bool = True, + ) -> List: + """ + Get the results of all task runs associated with the futures in the list. + + Args: + timeout: The maximum number of seconds to wait for all futures to + complete. + raise_on_failure: If `True`, an exception will be raised if any task run fails. + + Returns: + A list of results of the task runs. + + Raises: + TimeoutError: If the timeout is reached before all futures complete. + """ + try: + with timeout_context(timeout): + return [ + future.result(raise_on_failure=raise_on_failure) for future in self + ] + except TimeoutError as exc: + # timeout came from inside the task + if "Scope timed out after {timeout} second(s)." not in str(exc): + raise + raise TimeoutError( + f"Timed out waiting for all futures to complete within {timeout} seconds" + ) from exc + + def resolve_futures_to_states( expr: Union[PrefectFuture, Any], ) -> Union[State, Any]: diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py index 2f16eeac5b0b..74f3d2e506c5 100644 --- a/src/prefect/task_runners.py +++ b/src/prefect/task_runners.py @@ -25,6 +25,7 @@ PrefectConcurrentFuture, PrefectDistributedFuture, PrefectFuture, + PrefectFutureList, ) from prefect.logging.loggers import get_logger, get_run_logger from prefect.utilities.annotations import allow_failure, quote, unmapped @@ -97,7 +98,7 @@ def map( task: "Task", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[F]: + ) -> PrefectFutureList[F]: """ Submit multiple tasks to the task run engine. @@ -169,7 +170,7 @@ def map( map_length = list(lengths)[0] - futures = [] + futures: List[PrefectFuture] = [] for i in range(map_length): call_parameters = { key: value[i] for key, value in iterable_parameters.items() @@ -199,7 +200,7 @@ def map( ) ) - return futures + return PrefectFutureList(futures) def __enter__(self): if self._started: @@ -316,7 +317,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectConcurrentFuture[R]]: + ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: ... @overload @@ -325,7 +326,7 @@ def map( task: "Task[Any, R]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectConcurrentFuture[R]]: + ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: ... def map( @@ -427,7 +428,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectDistributedFuture[R]]: + ) -> PrefectFutureList[PrefectDistributedFuture[R]]: ... @overload @@ -436,7 +437,7 @@ def map( task: "Task[Any, R]", parameters: Dict[str, Any], wait_for: Optional[Iterable[PrefectFuture]] = None, - ) -> List[PrefectDistributedFuture[R]]: + ) -> PrefectFutureList[PrefectDistributedFuture[R]]: ... def map( diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 7ffc4a8afdea..fa9cfbd5ef7d 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -43,7 +43,7 @@ TaskRunContext, serialize_context, ) -from prefect.futures import PrefectDistributedFuture, PrefectFuture +from prefect.futures import PrefectDistributedFuture, PrefectFuture, PrefectFutureList from prefect.logging.loggers import get_logger from prefect.results import ResultFactory, ResultSerializer, ResultStorage from prefect.settings import ( @@ -996,7 +996,7 @@ def map( self: "Task[P, NoReturn]", *args: P.args, **kwargs: P.kwargs, - ) -> List[PrefectFuture[NoReturn]]: + ) -> PrefectFutureList[PrefectFuture[NoReturn]]: ... @overload @@ -1004,7 +1004,7 @@ def map( self: "Task[P, Coroutine[Any, Any, T]]", *args: P.args, **kwargs: P.kwargs, - ) -> List[PrefectFuture[T]]: + ) -> PrefectFutureList[PrefectFuture[T]]: ... @overload @@ -1012,7 +1012,7 @@ def map( self: "Task[P, T]", *args: P.args, **kwargs: P.kwargs, - ) -> List[PrefectFuture[T]]: + ) -> PrefectFutureList[PrefectFuture[T]]: ... @overload @@ -1021,7 +1021,7 @@ def map( *args: P.args, return_state: Literal[True], **kwargs: P.kwargs, - ) -> List[State[T]]: + ) -> PrefectFutureList[State[T]]: ... @overload @@ -1030,7 +1030,7 @@ def map( *args: P.args, return_state: Literal[True], **kwargs: P.kwargs, - ) -> List[State[T]]: + ) -> PrefectFutureList[State[T]]: ... def map( @@ -1044,8 +1044,9 @@ def map( """ Submit a mapped run of the task to a worker. - Must be called within a flow function. If writing an async task, this - call must be awaited. + Must be called within a flow run context. Will return a list of futures + that should be waited on before exiting the flow context to ensure all + mapped tasks have completed. Must be called with at least one iterable and all iterables must be the same length. Any arguments that are not iterable will be treated as @@ -1083,15 +1084,14 @@ def map( >>> from prefect import flow >>> @flow >>> def my_flow(): - >>> my_task.map([1, 2, 3]) + >>> return my_task.map([1, 2, 3]) Wait for all mapped tasks to finish >>> @flow >>> def my_flow(): >>> futures = my_task.map([1, 2, 3]) - >>> for future in futures: - >>> future.wait() + >>> futures.wait(): >>> # Now all of the mapped tasks have finished >>> my_task(10) @@ -1100,8 +1100,8 @@ def map( >>> @flow >>> def my_flow(): >>> futures = my_task.map([1, 2, 3]) - >>> for future in futures: - >>> print(future.result()) + >>> for x in futures.result(): + >>> print(x) >>> my_flow() 2 3 @@ -1122,6 +1122,7 @@ def map( >>> >>> # task 2 will wait for task_1 to complete >>> y = task_2.map([1, 2, 3], wait_for=[x]) + >>> return y Use a non-iterable input as a constant across mapped tasks >>> @task @@ -1130,7 +1131,7 @@ def map( >>> >>> @flow >>> def my_flow(): - >>> display.map("Check it out: ", [1, 2, 3]) + >>> return display.map("Check it out: ", [1, 2, 3]) >>> >>> my_flow() Check it out: 1 diff --git a/tests/test_futures.py b/tests/test_futures.py index 5d6dddaa3219..a5be27d6b399 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -3,7 +3,7 @@ from collections import OrderedDict from concurrent.futures import Future from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, List, Optional import pytest @@ -13,6 +13,8 @@ from prefect.futures import ( PrefectConcurrentFuture, PrefectDistributedFuture, + PrefectFuture, + PrefectFutureList, PrefectWrappedFuture, resolve_futures_to_states, ) @@ -280,3 +282,79 @@ def my_task(): with pytest.raises(MissingResult, match="State data is missing"): future.result() + + +class TestPrefectFutureList: + def test_wait(self): + mock_futures = [MockFuture(data=i) for i in range(5)] + futures = PrefectFutureList(mock_futures) + # should not raise a TimeoutError + futures.wait() + + for future in futures: + assert future.state.is_completed() + + @pytest.mark.timeout(method="thread") # alarm-based pytest-timeout will interfere + def test_wait_with_timeout(self): + mock_futures: List[PrefectFuture] = [MockFuture(data=i) for i in range(5)] + hanging_future = Future() + mock_futures.append(PrefectConcurrentFuture(uuid.uuid4(), hanging_future)) + futures = PrefectFutureList(mock_futures) + # should not raise a TimeoutError or hang + futures.wait(timeout=0.01) + + def test_results(self): + mock_futures = [MockFuture(data=i) for i in range(5)] + futures = PrefectFutureList(mock_futures) + result = futures.result() + + for i, result in enumerate(result): + assert result == i + + def test_results_with_failure(self): + mock_futures: List[PrefectFuture] = [MockFuture(data=i) for i in range(5)] + failing_future = Future() + failing_future.set_exception(ValueError("oops")) + mock_futures.append(PrefectConcurrentFuture(uuid.uuid4(), failing_future)) + futures = PrefectFutureList(mock_futures) + + with pytest.raises(ValueError, match="oops"): + futures.result() + + def test_results_with_raise_on_failure_false(self): + mock_futures: List[PrefectFuture] = [MockFuture(data=i) for i in range(5)] + final_state = Failed(data=ValueError("oops")) + wrapped_future = Future() + wrapped_future.set_result(final_state) + mock_futures.append(PrefectConcurrentFuture(uuid.uuid4(), wrapped_future)) + futures = PrefectFutureList(mock_futures) + + result = futures.result(raise_on_failure=False) + + for i, result in enumerate(result): + if i == 5: + assert isinstance(result, ValueError) + else: + assert result == i + + @pytest.mark.timeout(method="thread") # alarm-based pytest-timeout will interfere + def test_results_with_timeout(self): + mock_futures: List[PrefectFuture] = [MockFuture(data=i) for i in range(5)] + failing_future = Future() + failing_future.set_exception(TimeoutError("oops")) + mock_futures.append(PrefectConcurrentFuture(uuid.uuid4(), failing_future)) + futures = PrefectFutureList(mock_futures) + + with pytest.raises(TimeoutError): + futures.result(timeout=0.01) + + def test_result_does_not_obscure_other_timeouts(self): + mock_futures: List[PrefectFuture] = [MockFuture(data=i) for i in range(5)] + final_state = Failed(data=TimeoutError("oops")) + wrapped_future = Future() + wrapped_future.set_result(final_state) + mock_futures.append(PrefectConcurrentFuture(uuid.uuid4(), wrapped_future)) + futures = PrefectFutureList(mock_futures) + + with pytest.raises(TimeoutError, match="oops"): + futures.result() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a4ef67dd2573..650e10917e33 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -3581,6 +3581,33 @@ async def test_flow(): await test_flow() + async def test_wait_mapped_tasks(self): + @task + def add_one(x): + return x + 1 + + @flow + def my_flow(): + futures = add_one.map([1, 2, 3]) + futures.wait() + for future in futures: + assert future.state.is_completed() + + my_flow() + + async def test_get_results_all_mapped_tasks(self): + @task + def add_one(x): + return x + 1 + + @flow + def my_flow(): + futures = add_one.map([1, 2, 3]) + results = futures.result() + assert results == [2, 3, 4] + + my_flow() + class TestTaskConstructorValidation: async def test_task_cannot_configure_too_many_custom_retry_delays(self):