diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 58bbf8e32dec..9ea19512afb6 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -113,18 +113,24 @@ jobs: - prefect-version: "2.13" server-incompatible: true server-disable-csrf: true + extra_docker_run_options: '--env EXTRA_PIP_PACKAGES="prefect-kubernetes<0.4"' - prefect-version: "2.14" server-incompatible: true server-disable-csrf: true + extra_docker_run_options: '--env EXTRA_PIP_PACKAGES="prefect-kubernetes<0.4"' - prefect-version: "2.15" server-incompatible: true server-disable-csrf: true + extra_docker_run_options: '--env EXTRA_PIP_PACKAGES="prefect-kubernetes<0.4"' - prefect-version: "2.16" server-incompatible: false + extra_docker_run_options: '--env EXTRA_PIP_PACKAGES="prefect-kubernetes<0.4"' - prefect-version: "2.17" server-incompatible: false + extra_docker_run_options: '--env EXTRA_PIP_PACKAGES="prefect-kubernetes<0.4"' - prefect-version: "2.18" server-incompatible: false + extra_docker_run_options: '--env EXTRA_PIP_PACKAGES="prefect-kubernetes<0.4"' steps: - uses: actions/checkout@v4 diff --git a/requirements-client.txt b/requirements-client.txt index 2af268a02681..4cea46928cac 100644 --- a/requirements-client.txt +++ b/requirements-client.txt @@ -1,4 +1,4 @@ -anyio >= 3.7.1, < 4.0.0 +anyio >= 4.4.0, < 5.0.0 asgi-lifespan >= 1.0, < 3.0 cachetools >= 5.3, < 6.0 cloudpickle >= 2.0, < 4.0 diff --git a/requirements.txt b/requirements.txt index ba5a41cc4ae9..2e90043b175b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ pytz >= 2021.1, < 2025 readchar >= 4.0.0, < 5.0.0 sqlalchemy[asyncio] >= 1.4.22, != 1.4.33, < 3.0.0 typer >= 0.12.0, != 0.12.2, < 0.13.0 +exceptiongroup >= 1.2.1 diff --git a/scripts/wait-for-server.py b/scripts/wait-for-server.py index 47467afdbae8..2492b0758ec1 100755 --- a/scripts/wait-for-server.py +++ b/scripts/wait-for-server.py @@ -24,7 +24,7 @@ async def main(timeout): - async with anyio.move_on_after(timeout): + with anyio.move_on_after(timeout): print("Retrieving client...") async with get_client() as client: print("Connecting", end="") diff --git a/src/prefect/cli/agent.py b/src/prefect/cli/agent.py index e5061a845d0c..866e9e280e64 100644 --- a/src/prefect/cli/agent.py +++ b/src/prefect/cli/agent.py @@ -3,6 +3,7 @@ """ import os +from asyncio import CancelledError from functools import partial from typing import List, Optional from uuid import UUID @@ -18,6 +19,7 @@ from prefect.client import get_client from prefect.client.schemas.filters import WorkQueueFilter, WorkQueueFilterName from prefect.exceptions import ObjectNotFound +from prefect.logging import get_logger from prefect.settings import ( PREFECT_AGENT_PREFETCH_SECONDS, PREFECT_AGENT_QUERY_INTERVAL, @@ -26,6 +28,8 @@ from prefect.utilities.processutils import setup_signal_handlers_agent from prefect.utilities.services import critical_service_loop +logger = get_logger(__name__) + agent_app = PrefectTyper( name="agent", help="Commands for starting and interacting with agent processes.", @@ -219,30 +223,33 @@ async def start( f"queue(s): {', '.join(work_queues)}..." 
) - async with anyio.create_task_group() as tg: - tg.start_soon( - partial( - critical_service_loop, - agent.get_and_submit_flow_runs, - PREFECT_AGENT_QUERY_INTERVAL.value(), - printer=app.console.print, - run_once=run_once, - jitter_range=0.3, - backoff=4, # Up to ~1 minute interval during backoff + try: + async with anyio.create_task_group() as tg: + tg.start_soon( + partial( + critical_service_loop, + agent.get_and_submit_flow_runs, + PREFECT_AGENT_QUERY_INTERVAL.value(), + printer=app.console.print, + run_once=run_once, + jitter_range=0.3, + backoff=4, # Up to ~1 minute interval during backoff + ) ) - ) - tg.start_soon( - partial( - critical_service_loop, - agent.check_for_cancelled_flow_runs, - PREFECT_AGENT_QUERY_INTERVAL.value() * 2, - printer=app.console.print, - run_once=run_once, - jitter_range=0.3, - backoff=4, + tg.start_soon( + partial( + critical_service_loop, + agent.check_for_cancelled_flow_runs, + PREFECT_AGENT_QUERY_INTERVAL.value() * 2, + printer=app.console.print, + run_once=run_once, + jitter_range=0.3, + backoff=4, + ) ) - ) + except CancelledError: + logger.debug("Agent task group cancelled") app.console.print("Agent stopped!") diff --git a/src/prefect/cli/server.py b/src/prefect/cli/server.py index 825eeaf93bf4..8ac4bfa52ab6 100644 --- a/src/prefect/cli/server.py +++ b/src/prefect/cli/server.py @@ -4,6 +4,7 @@ import os import textwrap +from asyncio import CancelledError from functools import partial import anyio @@ -115,44 +116,47 @@ async def start( base_url = f"http://{host}:{port}" - async with anyio.create_task_group() as tg: - app.console.print(generate_welcome_blurb(base_url, ui_enabled=ui)) - app.console.print("\n") - - server_process_id = await tg.start( - partial( - run_process, - command=[ - get_sys_executable(), - "-m", - "uvicorn", - "--app-dir", - # quote wrapping needed for windows paths with spaces - f'"{prefect.__module_path__.parent}"', - "--factory", - "prefect.server.api.server:create_app", - "--host", - str(host), - "--port", - str(port), - "--timeout-keep-alive", - str(keep_alive_timeout), - ], - env=server_env, - stream_output=True, + try: + async with anyio.create_task_group() as tg: + app.console.print(generate_welcome_blurb(base_url, ui_enabled=ui)) + app.console.print("\n") + + server_process_id = await tg.start( + partial( + run_process, + command=[ + get_sys_executable(), + "-m", + "uvicorn", + "--app-dir", + # quote wrapping needed for windows paths with spaces + f'"{prefect.__module_path__.parent}"', + "--factory", + "prefect.server.api.server:create_app", + "--host", + str(host), + "--port", + str(port), + "--timeout-keep-alive", + str(keep_alive_timeout), + ], + env=server_env, + stream_output=True, + ) ) - ) - # Explicitly handle the interrupt signal here, as it will allow us to - # cleanly stop the uvicorn server. Failing to do that may cause a - # large amount of anyio error traces on the terminal, because the - # SIGINT is handled by Typer/Click in this process (the parent process) - # and will start shutting down subprocesses: - # https://github.com/PrefectHQ/server/issues/2475 + # Explicitly handle the interrupt signal here, as it will allow us to + # cleanly stop the uvicorn server. 
Failing to do that may cause a + # large amount of anyio error traces on the terminal, because the + # SIGINT is handled by Typer/Click in this process (the parent process) + # and will start shutting down subprocesses: + # https://github.com/PrefectHQ/server/issues/2475 - setup_signal_handlers_server( - server_process_id, "the Prefect server", app.console.print - ) + setup_signal_handlers_server( + server_process_id, "the Prefect server", app.console.print + ) + except CancelledError: + logger.debug("Server task group cancelled") app.console.print("Server stopped!") diff --git a/src/prefect/cli/worker.py b/src/prefect/cli/worker.py index 9eb57a093478..871d07c35e1b 100644 --- a/src/prefect/cli/worker.py +++ b/src/prefect/cli/worker.py @@ -1,6 +1,7 @@ import json import os import threading +from asyncio import CancelledError from enum import Enum from functools import partial from typing import List, Optional, Type @@ -16,6 +17,7 @@ from prefect.client.orchestration import get_client from prefect.client.schemas.filters import WorkQueueFilter, WorkQueueFilterName from prefect.exceptions import ObjectNotFound +from prefect.logging import get_logger from prefect.plugins import load_prefect_collections from prefect.settings import ( PREFECT_WORKER_HEARTBEAT_SECONDS, @@ -32,6 +34,8 @@ from prefect.workers.base import BaseWorker from prefect.workers.server import start_healthcheck_server +logger = get_logger(__name__) + worker_app = PrefectTyper( name="worker", help="Commands for starting and interacting with workers." ) @@ -171,61 +175,64 @@ async def start( base_job_template=template_contents, ) as worker: app.console.print(f"Worker {worker.name!r} started!", style="green") - async with anyio.create_task_group() as tg: - # wait for an initial heartbeat to configure the worker - await worker.sync_with_backend() - # schedule the scheduled flow run polling loop - tg.start_soon( - partial( - critical_service_loop, - workload=worker.get_and_submit_flow_runs, - interval=PREFECT_WORKER_QUERY_SECONDS.value(), - run_once=run_once, - printer=app.console.print, - jitter_range=0.3, - backoff=4, # Up to ~1 minute interval during backoff + try: + async with anyio.create_task_group() as tg: + # wait for an initial heartbeat to configure the worker + await worker.sync_with_backend() + # schedule the scheduled flow run polling loop + tg.start_soon( + partial( + critical_service_loop, + workload=worker.get_and_submit_flow_runs, + interval=PREFECT_WORKER_QUERY_SECONDS.value(), + run_once=run_once, + printer=app.console.print, + jitter_range=0.3, + backoff=4, # Up to ~1 minute interval during backoff + ) ) - ) - # schedule the sync loop - tg.start_soon( - partial( - critical_service_loop, - workload=worker.sync_with_backend, - interval=worker.heartbeat_interval_seconds, - run_once=run_once, - printer=app.console.print, - jitter_range=0.3, - backoff=4, + # schedule the sync loop + tg.start_soon( + partial( + critical_service_loop, + workload=worker.sync_with_backend, + interval=worker.heartbeat_interval_seconds, + run_once=run_once, + printer=app.console.print, + jitter_range=0.3, + backoff=4, + ) ) - ) - tg.start_soon( - partial( - critical_service_loop, - workload=worker.check_for_cancelled_flow_runs, - interval=PREFECT_WORKER_QUERY_SECONDS.value() * 2, - run_once=run_once, - printer=app.console.print, - jitter_range=0.3, - backoff=4, + tg.start_soon( + partial( + critical_service_loop, + workload=worker.check_for_cancelled_flow_runs, + interval=PREFECT_WORKER_QUERY_SECONDS.value() * 2, + run_once=run_once, + 
printer=app.console.print, + jitter_range=0.3, + backoff=4, + ) ) - ) - started_event = await worker._emit_worker_started_event() - - # if --with-healthcheck was passed, start the healthcheck server - if with_healthcheck: - # we'll start the ASGI server in a separate thread so that - # uvicorn does not block the main thread - server_thread = threading.Thread( - name="healthcheck-server-thread", - target=partial( - start_healthcheck_server, - worker=worker, - query_interval_seconds=PREFECT_WORKER_QUERY_SECONDS.value(), - ), - daemon=True, - ) - server_thread.start() + started_event = await worker._emit_worker_started_event() + + # if --with-healthcheck was passed, start the healthcheck server + if with_healthcheck: + # we'll start the ASGI server in a separate thread so that + # uvicorn does not block the main thread + server_thread = threading.Thread( + name="healthcheck-server-thread", + target=partial( + start_healthcheck_server, + worker=worker, + query_interval_seconds=PREFECT_WORKER_QUERY_SECONDS.value(), + ), + daemon=True, + ) + server_thread.start() + except CancelledError: + logger.debug("Worker task group cancelled") await worker._emit_worker_stopped_event(started_event) app.console.print(f"Worker {worker.name!r} stopped!") diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 6d908c50567e..c96aa2cc9408 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -107,7 +107,7 @@ import anyio import pendulum -from anyio import start_blocking_portal +from anyio.from_thread import start_blocking_portal from typing_extensions import Literal import prefect @@ -211,6 +211,7 @@ _resolve_custom_task_run_name, capture_sigterm, check_api_reachable, + collapse_excgroups, collect_task_run_inputs, emit_task_run_state_change_event, propose_state, @@ -279,7 +280,7 @@ def enter_flow_run_engine_from_flow_call( # the user. Generally, you should enter contexts _within_ the async `begin_run` # instead but if you need to enter a context from the main thread you'll need to do # it here. - contexts = [capture_sigterm()] + contexts = [capture_sigterm(), collapse_excgroups()] if flow.isasync and ( not is_subflow_run or (is_subflow_run and parent_flow_run_context.flow.isasync) @@ -324,7 +325,7 @@ def enter_flow_run_engine_from_subprocess(flow_run_id: UUID) -> State: flow_run_id, user_thread=threading.current_thread(), ), - contexts=[capture_sigterm()], + contexts=[capture_sigterm(), collapse_excgroups()], ) APILogHandler.flush() @@ -2248,9 +2249,10 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: PrefectClient, flow This context _must_ reraise the exception to properly exit the run. """ try: - yield + with collapse_excgroups(): + yield except (Abort, Pause): # Do not capture internal signals as crashes raise @@ -2287,7 +2288,8 @@ async def report_task_run_crashes(task_run: TaskRun, client: PrefectClient): This context _must_ reraise the exception to properly exit the run.
""" try: - yield + with collapse_excgroups(): + yield except (Abort, Pause): # Do not capture internal signals as crashes raise diff --git a/src/prefect/task_server.py b/src/prefect/task_server.py index cdcb8fc4c825..3a28c81edcbc 100644 --- a/src/prefect/task_server.py +++ b/src/prefect/task_server.py @@ -9,6 +9,7 @@ from typing import List, Optional, Type import anyio +from exceptiongroup import BaseExceptionGroup # novermin from websockets.exceptions import InvalidStatusCode from prefect import Task, get_client @@ -225,16 +226,21 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun): validated_state=state, ) - self._runs_task_group.start_soon( - partial( - submit_autonomous_task_run_to_engine, - task=task, - task_run=task_run, - parameters=parameters, - task_runner=self.task_runner, - client=self._client, + try: + self._runs_task_group.start_soon( + partial( + submit_autonomous_task_run_to_engine, + task=task, + task_run=task_run, + parameters=parameters, + task_runner=self.task_runner, + client=self._client, + ) + ) + except BaseException as exc: + logger.exception( + f"Failed to submit task run {task_run.id!r} to engine", exc_info=exc ) - ) async def execute_task_run(self, task_run: TaskRun): """Execute a task run in the task server.""" @@ -301,6 +307,14 @@ def yell(message: str): try: await task_server.start() + except BaseExceptionGroup as exc: # novermin + exceptions = exc.exceptions + n_exceptions = len(exceptions) + logger.error( + f"Task worker stopped with {n_exceptions} exception{'s' if n_exceptions != 1 else ''}:" + f"\n" + "\n".join(str(e) for e in exceptions) + ) + except StopTaskServer: logger.info("Task server stopped.") diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index ffbdd2f7155b..888d82f884c2 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -29,7 +29,9 @@ import anyio import anyio.abc +import anyio.to_thread import sniffio +from anyio.from_thread import start_blocking_portal from typing_extensions import Literal, ParamSpec, TypeGuard from prefect.logging import get_logger @@ -134,7 +136,7 @@ async def run_sync_in_worker_thread( """ call = partial(__fn, *args, **kwargs) return await anyio.to_thread.run_sync( - call, cancellable=True, limiter=get_thread_limiter() + call, abandon_on_cancel=True, limiter=get_thread_limiter() ) @@ -202,7 +204,7 @@ async def send_interrupt_to_thread(): partial( anyio.to_thread.run_sync, capture_worker_thread_and_result, - cancellable=True, + abandon_on_cancel=True, limiter=get_thread_limiter(), ) ) @@ -228,7 +230,7 @@ def run_async_in_new_loop(__fn: Callable[..., Awaitable[T]], *args: Any, **kwarg def in_async_worker_thread() -> bool: try: - anyio.from_thread.threadlocals.current_async_module + anyio.from_thread.threadlocals.current_async_backend except AttributeError: return False else: @@ -338,7 +340,7 @@ def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwarg "`sync` called from an asynchronous context; " "you should `await` the async function directly instead." 
) - with anyio.start_blocking_portal() as portal: + with start_blocking_portal() as portal: return portal.call(partial(__async_fn, *args, **kwargs)) elif in_async_worker_thread(): # In a sync context but we can access the event loop thread; send the async diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index 1fdd794dbb1a..5701368fdd4c 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -4,11 +4,13 @@ import os import signal import time +from contextlib import contextmanager from functools import partial from typing import ( Any, Callable, Dict, + Generator, Iterable, Optional, Set, @@ -18,6 +20,7 @@ from uuid import UUID, uuid4 import anyio +from exceptiongroup import BaseExceptionGroup # novermin from typing_extensions import Literal import prefect @@ -734,3 +737,14 @@ def emit_task_run_state_change_event( }, follows=follows, ) + + +@contextmanager +def collapse_excgroups() -> Generator[None, None, None]: + try: + yield + except BaseException as exc: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc diff --git a/src/prefect/utilities/services.py b/src/prefect/utilities/services.py index 6cdfec5b8b74..cc2981c472ee 100644 --- a/src/prefect/utilities/services.py +++ b/src/prefect/utilities/services.py @@ -1,4 +1,5 @@ import sys +from asyncio import CancelledError from collections import deque from traceback import format_exception from types import TracebackType @@ -67,6 +68,11 @@ async def critical_service_loop( backoff_count = 0 track_record.append(True) + except CancelledError as exc: + # Exit immediately because the task was cancelled, possibly due + # to a signal or timeout. + logger.debug(f"Run of {workload!r} cancelled", exc_info=exc) + return except httpx.TransportError as exc: # httpx.TransportError is the base class for any kind of communications # error, like timeouts, connection failures, etc. This does _not_ cover @@ -138,7 +144,7 @@ async def critical_service_loop( failures.clear() printer( "Backing off due to consecutive errors, using increased interval of " - f" {interval * 2**backoff_count}s." + f" {interval * 2 ** backoff_count}s." 
) if run_once: diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index e1f6d8f844c1..9a6a4ba6d72f 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -4,6 +4,7 @@ from unittest import mock import pytest +from exceptiongroup import BaseExceptionGroup, catch # novermin import prefect.results from prefect import Task, task, unmapped @@ -262,7 +263,13 @@ async def test_stuck_pending_tasks_are_reenqueued( # now we simulate a stuck task by having the TaskServer try to run it but fail server = TaskServer(foo_task_with_result_storage) - with pytest.raises(ValueError): + + def assert_exception(exc_group: BaseExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], ValueError) + assert "woops" in str(exc_group.exceptions[0]) + + with catch({ValueError: assert_exception}): with mock.patch( "prefect.task_server.submit_autonomous_task_run_to_engine", side_effect=ValueError("woops"), diff --git a/tests/test_context.py b/tests/test_context.py index c20368c3d2ee..620c6f26a687 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -388,7 +388,7 @@ def test_root_settings_context_accessible_in_new_thread(self): @pytest.mark.usefixtures("remove_existing_settings_context") def test_root_settings_context_accessible_in_new_loop(self): - from anyio import start_blocking_portal + from anyio.from_thread import start_blocking_portal with start_blocking_portal() as portal: result = portal.call(get_settings_context) diff --git a/tests/test_engine.py b/tests/test_engine.py index 11e817a7832f..4800ce4ae744 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -12,6 +12,7 @@ import anyio import pendulum import pytest +from exceptiongroup import BaseExceptionGroup, catch # novermin from prefect._internal.pydantic import HAS_PYDANTIC_V2 @@ -434,9 +435,14 @@ async def pausing_flow_without_blocking(): await foo(wait_for=[x, y]) assert False, "This line should not be reached" + def assert_exception(exc_group: BaseExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], Pause) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - with pytest.raises(Pause): + + with catch({Pause: assert_exception}): await pausing_flow_without_blocking(return_state=True) flow_run = await prefect_client.read_flow_run(flow_run_id) diff --git a/tests/utilities/test_asyncutils.py b/tests/utilities/test_asyncutils.py index d3929f3510c1..63a5e48ff29f 100644 --- a/tests/utilities/test_asyncutils.py +++ b/tests/utilities/test_asyncutils.py @@ -9,6 +9,7 @@ import anyio import pytest +from exceptiongroup import BaseExceptionGroup # novermin from prefect.context import ContextModel from prefect.settings import ( @@ -194,9 +195,12 @@ async def test_run_sync_in_interruptible_worker_thread_does_not_hide_exceptions( def foo(): raise ValueError("test") - with pytest.raises(ValueError, match="test"): + with pytest.raises(BaseExceptionGroup) as exc: await run_sync_in_interruptible_worker_thread(foo) + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], ValueError) + async def test_run_sync_in_interruptible_worker_thread_does_not_hide_base_exceptions(): class LikeKeyboardInterrupt(BaseException): @@ -205,9 +209,12 @@ class LikeKeyboardInterrupt(BaseException): def foo(): raise LikeKeyboardInterrupt("test") - with pytest.raises(LikeKeyboardInterrupt, match="test"): + with pytest.raises(BaseExceptionGroup) as exc: await 
run_sync_in_interruptible_worker_thread(foo) + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], LikeKeyboardInterrupt) + async def test_run_sync_in_interruptible_worker_thread_function_can_return_exception(): def foo(): diff --git a/tests/workers/test_process_worker.py b/tests/workers/test_process_worker.py index 58a84384e46e..0bd37ea895b0 100644 --- a/tests/workers/test_process_worker.py +++ b/tests/workers/test_process_worker.py @@ -11,6 +11,7 @@ import anyio.abc import pendulum import pytest +from exceptiongroup import BaseExceptionGroup # novermin from sqlalchemy.ext.asyncio import AsyncSession from prefect._internal.pydantic import HAS_PYDANTIC_V2 @@ -348,7 +349,8 @@ async def test_process_created_then_marked_as_started( patch_client(monkeypatch) fake_configuration = MagicMock() fake_configuration.command = "echo hello" - with pytest.raises(RuntimeError, match="Started called!"): + + with pytest.raises(BaseExceptionGroup) as exc: async with ProcessWorker( work_pool_name=work_pool.name, ) as worker: @@ -359,6 +361,9 @@ async def test_process_created_then_marked_as_started( task_status=fake_status, ) + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], RuntimeError) + fake_status.started.assert_called_once() mock_open_process.assert_awaited_once()