From 50a95ca3deeec537f5797bd534690d0eada4d43c Mon Sep 17 00:00:00 2001 From: Hamada Salhab Date: Tue, 1 Oct 2024 22:14:11 +0300 Subject: [PATCH] fix(agents-api): Fix updating task execution (#542) > [!IMPORTANT] > Improves task execution updates by encoding task tokens with metadata and enhancing error handling in `raise_complete_async.py` and `update_execution.py`. > > - **Behavior**: > - In `raise_complete_async.py`, task tokens are now base64 encoded and include metadata with activity, run, and workflow IDs. > - In `update_execution.py`, added error handling for stopping and resuming executions, using metadata for async activity handle retrieval. > - **Database Queries**: > - In `get_paused_execution_token.py`, query updated to include `metadata` and sort by `created_at` with a limit of 1. > - **Error Handling**: > - Added try-except blocks in `update_execution.py` to handle exceptions when stopping or resuming executions. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral) for f3de5274bf994fd43c1e9bceee13440256101806. It will automatically update as commits are pushed. --------- Signed-off-by: Diwank Singh Tomer Co-authored-by: Diwank Singh Tomer --- .../task_steps/raise_complete_async.py | 19 ++++++---- .../execution/get_paused_execution_token.py | 5 ++- .../routers/tasks/update_execution.py | 35 ++++++++++++++----- 3 files changed, 43 insertions(+), 16 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/raise_complete_async.py b/agents-api/agents_api/activities/task_steps/raise_complete_async.py index 3944c914f..a73df3f8d 100644 --- a/agents-api/agents_api/activities/task_steps/raise_complete_async.py +++ b/agents-api/agents_api/activities/task_steps/raise_complete_async.py @@ -1,3 +1,4 @@ +import base64 from temporalio import activity from ...autogen.openapi_model import CreateTransitionRequest @@ -10,21 +11,27 @@ @activity.defn async def raise_complete_async(context: StepContext, output: StepOutcome) -> None: - # TODO: Create a transtition to "wait" and save the captured_token to the transition + + activity_info = activity.info() + + captured_token = base64.b64encode(activity_info.task_token).decode('ascii') + activity_id = activity_info.activity_id + workflow_run_id = activity_info.workflow_run_id + workflow_id = activity_info.workflow_id - captured_token = activity.info().task_token - captured_token = captured_token.decode("latin-1") transition_info = CreateTransitionRequest( current=context.cursor, type="wait", next=None, output=output, task_token=captured_token, + metadata={ + "x-activity-id": activity_id, + "x-run-id": workflow_run_id, + "x-workflow-id": workflow_id, + }, ) await original_transition_step(context, transition_info) - # await transition(context, output=output, type="wait", next=None, task_token=captured_token) - - print("transition to wait called") activity.raise_complete_async() diff --git a/agents-api/agents_api/models/execution/get_paused_execution_token.py b/agents-api/agents_api/models/execution/get_paused_execution_token.py index d8b0945c1..b4c9f9081 100644 --- a/agents-api/agents_api/models/execution/get_paused_execution_token.py +++ b/agents-api/agents_api/models/execution/get_paused_execution_token.py @@ -49,7 +49,7 @@ def get_paused_execution_token( """ get_query = """ - ?[task_token, max(created_at)] := + ?[task_token, created_at, metadata] := execution_id = to_uuid($execution_id), *executions { execution_id, @@ -59,9 +59,12 @@ def get_paused_execution_token( created_at, task_token, type, + metadata, }, type = "wait" + :sort -created_at + :limit 1 """ queries = [ diff --git a/agents-api/agents_api/routers/tasks/update_execution.py b/agents-api/agents_api/routers/tasks/update_execution.py index 3f6b30e8c..a5ca30aab 100644 --- a/agents-api/agents_api/routers/tasks/update_execution.py +++ b/agents-api/agents_api/routers/tasks/update_execution.py @@ -1,3 +1,4 @@ +import base64 from typing import Annotated from uuid import UUID @@ -29,19 +30,35 @@ async def update_execution( match data: case StopExecutionRequest(): - wf_handle = temporal_client.get_workflow_handle_for( - *get_temporal_workflow_data(execution_id=execution_id) - ) - await wf_handle.cancel() + try: + wf_handle = temporal_client.get_workflow_handle_for( + *get_temporal_workflow_data(execution_id=execution_id) + ) + await wf_handle.cancel() + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to stop execution") case ResumeExecutionRequest(): token_data = get_paused_execution_token( developer_id=x_developer_id, execution_id=execution_id ) - act_handle = temporal_client.get_async_activity_handle( - task_token=str.encode(token_data["task_token"], encoding="latin-1") - ) - await act_handle.complete(data.input) - print("Resumed execution successfully") + activity_id = token_data["metadata"].get("x-activity-id", None) + run_id = token_data["metadata"].get("x-run-id", None) + workflow_id = token_data["metadata"].get("x-workflow-id", None) + if activity_id is None or run_id is None or workflow_id is None: + act_handle = temporal_client.get_async_activity_handle( + task_token=base64.b64decode(token_data["task_token"].encode('ascii')), + ) + + else: + act_handle = temporal_client.get_async_activity_handle( + activity_id=activity_id, + workflow_id=workflow_id, + run_id=run_id, + ) + try: + await act_handle.complete(data.input) + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to resume execution") case _: raise HTTPException(status_code=400, detail="Invalid request data")