diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index c2de00c17c07..47d803d63191 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -7,7 +7,6 @@ from typing import ( Any, AsyncGenerator, - Callable, Coroutine, Dict, Generator, @@ -415,7 +414,7 @@ def create_flow_run(self, client: SyncPrefectClient) -> FlowRun: return flow_run - def call_hooks(self, state: Optional[State] = None) -> Iterable[Callable]: + def call_hooks(self, state: Optional[State] = None): if state is None: state = self.state flow = self.flow @@ -613,11 +612,7 @@ def start(self) -> Generator[None, None, None]: if self.state.is_running(): self.call_hooks() - try: - yield - finally: - if self.state.is_final() or self.state.is_cancelling(): - self.call_hooks() + yield @contextmanager def run_context(self): @@ -638,6 +633,9 @@ def run_context(self): except Exception as exc: self.logger.exception("Encountered exception during execution: %r", exc) self.handle_exception(exc) + finally: + if self.state.is_final() or self.state.is_cancelling(): + self.call_hooks() def call_flow_fn(self) -> Union[R, Coroutine[Any, Any, R]]: """ diff --git a/tests/test_flows.py b/tests/test_flows.py index 11fe47740ef7..8bb710eac461 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -31,6 +31,7 @@ IntervalSchedule, RRuleSchedule, ) +from prefect.context import FlowRunContext, get_run_context from prefect.deployments.runner import RunnerDeployment from prefect.docker.docker_image import DockerImage from prefect.events import DeploymentEventTrigger, Posture @@ -40,6 +41,7 @@ ParameterTypeError, ReservedArgumentError, ScriptError, + UnfinishedRun, ) from prefect.filesystems import LocalFileSystem from prefect.flows import ( @@ -60,6 +62,7 @@ ) from prefect.states import ( Cancelled, + Cancelling, Paused, PausedRun, State, @@ -2809,6 +2812,39 @@ async def my_hook(flow, flow_run, state): return my_hook +class TestFlowHooksContext: + @pytest.mark.parametrize( + "hook_type, fn_body, expected_exc", + [ + ("on_completion", lambda: None, None), + ("on_failure", lambda: 100 / 0, ZeroDivisionError), + ("on_cancellation", lambda: Cancelling(), UnfinishedRun), + ], + ) + def test_hooks_are_called_within_flow_run_context( + self, caplog, hook_type, fn_body, expected_exc + ): + def hook(flow, flow_run, state): + ctx: FlowRunContext = get_run_context() # type: ignore + assert ctx is not None + assert ctx.flow_run and ctx.flow_run == flow_run + assert ctx.flow_run.state == state + assert ctx.flow == flow + + @flow(**{hook_type: [hook]}) # type: ignore + def foo_flow(): + return fn_body() + + with caplog.at_level("INFO"): + if expected_exc: + with pytest.raises(expected_exc): + foo_flow() + else: + foo_flow() + + assert "Hook 'hook' finished running successfully" in caplog.text + + class TestFlowHooksWithKwargs: def test_hook_with_extra_default_arg(self): data = {}