Skip to content

Commit

Permalink
fix(agents-api): Fix updating task execution (#542)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Improves task execution updates by encoding task tokens with metadata
and enhancing error handling in `raise_complete_async.py` and
`update_execution.py`.
> 
>   - **Behavior**:
> - In `raise_complete_async.py`, task tokens are now base64 encoded and
include metadata with activity, run, and workflow IDs.
> - In `update_execution.py`, added error handling for stopping and
resuming executions, using metadata for async activity handle retrieval.
>   - **Database Queries**:
> - In `get_paused_execution_token.py`, query updated to include
`metadata` and sort by `created_at` with a limit of 1.
>   - **Error Handling**:
> - Added try-except blocks in `update_execution.py` to handle
exceptions when stopping or resuming executions.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral)<sup>
for f3de527. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->

---------

Signed-off-by: Diwank Singh Tomer <[email protected]>
Co-authored-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
HamadaSalhab and creatorrr authored Oct 1, 2024
1 parent f092a49 commit 50a95ca
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest
Expand All @@ -10,21 +11,27 @@

@activity.defn
async def raise_complete_async(context: StepContext, output: StepOutcome) -> None:
# TODO: Create a transtition to "wait" and save the captured_token to the transition

activity_info = activity.info()

captured_token = base64.b64encode(activity_info.task_token).decode('ascii')
activity_id = activity_info.activity_id
workflow_run_id = activity_info.workflow_run_id
workflow_id = activity_info.workflow_id

captured_token = activity.info().task_token
captured_token = captured_token.decode("latin-1")
transition_info = CreateTransitionRequest(
current=context.cursor,
type="wait",
next=None,
output=output,
task_token=captured_token,
metadata={
"x-activity-id": activity_id,
"x-run-id": workflow_run_id,
"x-workflow-id": workflow_id,
},
)

await original_transition_step(context, transition_info)

# await transition(context, output=output, type="wait", next=None, task_token=captured_token)

print("transition to wait called")
activity.raise_complete_async()
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_paused_execution_token(
"""

get_query = """
?[task_token, max(created_at)] :=
?[task_token, created_at, metadata] :=
execution_id = to_uuid($execution_id),
*executions {
execution_id,
Expand All @@ -59,9 +59,12 @@ def get_paused_execution_token(
created_at,
task_token,
type,
metadata,
},
type = "wait"
:sort -created_at
:limit 1
"""

queries = [
Expand Down
35 changes: 26 additions & 9 deletions agents-api/agents_api/routers/tasks/update_execution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from typing import Annotated
from uuid import UUID

Expand Down Expand Up @@ -29,19 +30,35 @@ async def update_execution(

match data:
case StopExecutionRequest():
wf_handle = temporal_client.get_workflow_handle_for(
*get_temporal_workflow_data(execution_id=execution_id)
)
await wf_handle.cancel()
try:
wf_handle = temporal_client.get_workflow_handle_for(
*get_temporal_workflow_data(execution_id=execution_id)
)
await wf_handle.cancel()
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to stop execution")

case ResumeExecutionRequest():
token_data = get_paused_execution_token(
developer_id=x_developer_id, execution_id=execution_id
)
act_handle = temporal_client.get_async_activity_handle(
task_token=str.encode(token_data["task_token"], encoding="latin-1")
)
await act_handle.complete(data.input)
print("Resumed execution successfully")
activity_id = token_data["metadata"].get("x-activity-id", None)
run_id = token_data["metadata"].get("x-run-id", None)
workflow_id = token_data["metadata"].get("x-workflow-id", None)
if activity_id is None or run_id is None or workflow_id is None:
act_handle = temporal_client.get_async_activity_handle(
task_token=base64.b64decode(token_data["task_token"].encode('ascii')),
)

else:
act_handle = temporal_client.get_async_activity_handle(
activity_id=activity_id,
workflow_id=workflow_id,
run_id=run_id,
)
try:
await act_handle.complete(data.input)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to resume execution")
case _:
raise HTTPException(status_code=400, detail="Invalid request data")

0 comments on commit 50a95ca

Please sign in to comment.