Skip to content

Commit

Permalink
Merge branch 'dev' into d/cookbookfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Oct 17, 2024
2 parents dbac79f + c77a1a4 commit b25c42d
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 59 deletions.
24 changes: 24 additions & 0 deletions agents-api/agents_api/activities/sync_items_remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ..common.protocol.remote import RemoteObject


@beartype
async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
from ..common.storage_handler import store_in_blob_store_if_large

return [store_in_blob_store_if_large(input) for input in inputs]


@beartype
async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
from ..common.storage_handler import load_from_blob_store_if_remote

return [load_from_blob_store_if_remote(input) for input in inputs]


save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)
load_inputs_remote = activity.defn(name="load_inputs_remote")(load_inputs_remote_fn)
2 changes: 1 addition & 1 deletion agents-api/agents_api/clients/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def add_object(key: str, body: bytes, replace: bool = False) -> None:
client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)


@lru_cache(maxsize=256 * 1024 // blob_store_cutoff_kb) # 256mb in cache
@lru_cache(maxsize=256 * 1024 // max(1, blob_store_cutoff_kb)) # 256mb in cache
@beartype
def get_object(key: str) -> bytes:
client = get_s3_client()
Expand Down
7 changes: 6 additions & 1 deletion agents-api/agents_api/common/exceptions/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def is_non_retryable_error(error: BaseException) -> bool:

# Check for specific HTTP errors (status code == 429)
if isinstance(error, httpx.HTTPStatusError):
if error.response.status_code in (408, 429, 503, 504):
if error.response.status_code in (
408,
429,
503,
504,
): # pytype: disable=attribute-error
return False

# If we don't know about the error, we should not retry
Expand Down
236 changes: 236 additions & 0 deletions agents-api/agents_api/common/protocol/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
from dataclasses import dataclass
from typing import Any, Iterator

from temporalio import activity, workflow

with workflow.unsafe.imports_passed_through():
from pydantic import BaseModel

from ...env import blob_store_bucket


@dataclass
class RemoteObject:
key: str
bucket: str = blob_store_bucket


class BaseRemoteModel(BaseModel):
_remote_cache: dict[str, Any]

class Config:
arbitrary_types_allowed = True

def __init__(self, **data: Any):
super().__init__(**data)
self._remote_cache = {}

def __load_item(self, item: Any | RemoteObject) -> Any:
if not activity.in_activity():
return item

from ..storage_handler import load_from_blob_store_if_remote

return load_from_blob_store_if_remote(item)

def __save_item(self, item: Any) -> Any:
if not activity.in_activity():
return item

from ..storage_handler import store_in_blob_store_if_large

return store_in_blob_store_if_large(item)

def __getattribute__(self, name: str) -> Any:
if name.startswith("_"):
return super().__getattribute__(name)

try:
value = super().__getattribute__(name)
except AttributeError:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)

if isinstance(value, RemoteObject):
cache = super().__getattribute__("_remote_cache")
if name in cache:
return cache[name]

loaded_data = self.__load_item(value)
cache[name] = loaded_data
return loaded_data

return value

def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, value)
return

stored_value = self.__save_item(value)
super().__setattr__(name, stored_value)

if isinstance(stored_value, RemoteObject):
cache = self.__dict__.get("_remote_cache", {})
cache.pop(name, None)

def unload_attribute(self, name: str) -> None:
if name in self._remote_cache:
data = self._remote_cache.pop(name)
remote_obj = self.__save_item(data)
super().__setattr__(name, remote_obj)

def unload_all(self) -> None:
for name in list(self._remote_cache.keys()):
self.unload_attribute(name)


class RemoteList(list):
_remote_cache: dict[int, Any]

def __init__(self, iterable: list[Any] | None = None):
super().__init__()
self._remote_cache: dict[int, Any] = {}
if iterable:
for item in iterable:
self.append(item)

def __load_item(self, item: Any | RemoteObject) -> Any:
if not activity.in_activity():
return item

from ..storage_handler import load_from_blob_store_if_remote

return load_from_blob_store_if_remote(item)

def __save_item(self, item: Any) -> Any:
if not activity.in_activity():
return item

from ..storage_handler import store_in_blob_store_if_large

return store_in_blob_store_if_large(item)

def __getitem__(self, index: int | slice) -> Any:
if isinstance(index, slice):
# Obtain the slice without triggering __getitem__ recursively
sliced_items = super().__getitem__(
index
) # This returns a list of items as is
return RemoteList._from_existing_items(sliced_items)
else:
value = super().__getitem__(index)

if isinstance(value, RemoteObject):
if index in self._remote_cache:
return self._remote_cache[index]
loaded_data = self.__load_item(value)
self._remote_cache[index] = loaded_data
return loaded_data
return value

@classmethod
def _from_existing_items(cls, items: list[Any]) -> "RemoteList":
"""
Create a RemoteList from existing items without processing them again.
This method ensures that slicing does not trigger loading of items.
"""
new_remote_list = cls.__new__(
cls
) # Create a new instance without calling __init__
list.__init__(new_remote_list) # Initialize as an empty list
new_remote_list._remote_cache = {}
new_remote_list._extend_without_processing(items)
return new_remote_list

def _extend_without_processing(self, items: list[Any]) -> None:
"""
Extend the list without processing the items (i.e., without storing them again).
"""
super().extend(items)

def __setitem__(self, index: int | slice, value: Any) -> None:
if isinstance(index, slice):
# Handle slice assignment without processing existing RemoteObjects
processed_values = [self.__save_item(v) for v in value]
super().__setitem__(index, processed_values)
# Clear cache for affected indices
for i in range(*index.indices(len(self))):
self._remote_cache.pop(i, None)
else:
stored_value = self.__save_item(value)
super().__setitem__(index, stored_value)
self._remote_cache.pop(index, None)

def append(self, value: Any) -> None:
stored_value = self.__save_item(value)
super().append(stored_value)
# No need to cache immediately

def insert(self, index: int, value: Any) -> None:
stored_value = self.__save_item(value)
super().insert(index, stored_value)
# Adjust cache indices
self._shift_cache_on_insert(index)

def _shift_cache_on_insert(self, index: int) -> None:
new_cache = {}
for i, v in self._remote_cache.items():
if i >= index:
new_cache[i + 1] = v
else:
new_cache[i] = v
self._remote_cache = new_cache

def remove(self, value: Any) -> None:
# Find the index of the value to remove
index = self.index(value)
super().remove(value)
self._remote_cache.pop(index, None)
# Adjust cache indices
self._shift_cache_on_remove(index)

def _shift_cache_on_remove(self, index: int) -> None:
new_cache = {}
for i, v in self._remote_cache.items():
if i > index:
new_cache[i - 1] = v
elif i < index:
new_cache[i] = v
# Else: i == index, already removed
self._remote_cache = new_cache

def pop(self, index: int = -1) -> Any:
value = super().pop(index)
# Adjust negative indices
if index < 0:
index = len(self) + index
self._remote_cache.pop(index, None)
# Adjust cache indices
self._shift_cache_on_remove(index)
return value

def clear(self) -> None:
super().clear()
self._remote_cache.clear()

def extend(self, iterable: list[Any]) -> None:
for item in iterable:
self.append(item)

def __iter__(self) -> Iterator[Any]:
for index in range(len(self)):
yield self.__getitem__(index)

def unload_item(self, index: int) -> None:
"""Unload a specific item and replace it with a RemoteObject."""
if index in self._remote_cache:
data = self._remote_cache.pop(index)
remote_obj = self.__save_item(data)
super().__setitem__(index, remote_obj)

def unload_all(self) -> None:
"""Unload all cached items."""
for index in list(self._remote_cache.keys()):
self.unload_item(index)
64 changes: 31 additions & 33 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
from dataclasses import dataclass
from typing import Annotated, Any
from uuid import UUID

from pydantic import BaseModel, Field, computed_field
from pydantic_partial import create_partial_model

from ...autogen.openapi_model import (
Agent,
CreateTaskRequest,
CreateTransitionRequest,
Execution,
ExecutionStatus,
PartialTaskSpecDef,
PatchTaskRequest,
Session,
Task,
TaskSpec,
TaskSpecDef,
TaskToolDef,
Tool,
TransitionTarget,
TransitionType,
UpdateTaskRequest,
User,
Workflow,
WorkflowStep,
)
from temporalio import workflow

with workflow.unsafe.imports_passed_through():
from pydantic import BaseModel, Field, computed_field
from pydantic_partial import create_partial_model

from ...autogen.openapi_model import (
Agent,
CreateTaskRequest,
CreateTransitionRequest,
Execution,
ExecutionStatus,
PartialTaskSpecDef,
PatchTaskRequest,
Session,
Task,
TaskSpec,
TaskSpecDef,
TaskToolDef,
Tool,
TransitionTarget,
TransitionType,
UpdateTaskRequest,
User,
Workflow,
WorkflowStep,
)
from .remote import BaseRemoteModel, RemoteObject

# TODO: Maybe we should use a library for this

Expand Down Expand Up @@ -136,9 +139,9 @@ class ExecutionInput(BaseModel):
session: Session | None = None


class StepContext(BaseModel):
execution_input: ExecutionInput
inputs: list[Any]
class StepContext(BaseRemoteModel):
execution_input: ExecutionInput | RemoteObject
inputs: list[Any] | RemoteObject
cursor: TransitionTarget

@computed_field
Expand Down Expand Up @@ -216,11 +219,6 @@ class StepOutcome(BaseModel):
transition_to: tuple[TransitionType, TransitionTarget] | None = None


@dataclass
class RemoteObject:
key: str


def task_to_spec(
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
) -> TaskSpecDef | PartialTaskSpecDef:
Expand Down
Loading

0 comments on commit b25c42d

Please sign in to comment.