Skip to content

Commit

Permalink
refactor: queries abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Jan 17, 2025
1 parent 590f874 commit 023b006
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 63 deletions.
8 changes: 3 additions & 5 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scalar_fastapi import get_scalar_api_reference

from .clients.pg import create_db_pool
from .env import api_prefix, hostname, pool_max_size, protocol, public_port
from .env import api_prefix, hostname, pg_dsn, pool_max_size, protocol, public_port
from .queries.container import Queries


Expand Down Expand Up @@ -86,10 +86,8 @@ def create_app():
container = Queries()
# FIXME: This does not work
# container.init_resources()
container.config.db.dsn.from_env(
"PG_DSN",
default="postgres://postgres:[email protected]:5432/postgres?sslmode=disable",
)
container.config.db.dsn.from_value(pg_dsn)
container.config.db.client_pool_max_size.from_value(pool_max_size)
app.container = container

return app
Expand Down
24 changes: 13 additions & 11 deletions agents-api/agents_api/queries/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

from uuid import UUID

from typing import cast
from asyncpg import Record
from beartype import beartype
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateAgentRequest, ResourceCreatedResponse
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import increase_counter
from ..base_query import BaseQuery
from ..base_queries import AsyncpgBaseQuery
from ..utils import generate_canonical_name, pg_query, rewrap_exceptions, wrap_in_class

# Define the raw SQL query
Expand Down Expand Up @@ -100,7 +102,7 @@ async def create_agent(
)


class CreateAgentQuery(BaseQuery):
class CreateAgentQuery(AsyncpgBaseQuery, metrics=increase_counter):
query = """
INSERT INTO agents (
developer_id,
Expand All @@ -127,8 +129,13 @@ class CreateAgentQuery(BaseQuery):
RETURNING *;
"""

@rewrap_exceptions(common_db_exceptions("agent", ["create"]))
async def _execute(self, conn, developer_id: UUID, data: CreateAgentRequest) -> ResourceCreatedResponse:
single_result = True
errors_mapping = common_db_exceptions("agent", ["create"])

def transform_record(self, rec: Record):
return {"id": rec["agent_id"], "created_at": rec["created_at"]}

async def execute(self, *, developer_id: UUID, data: CreateAgentRequest) -> ResourceCreatedResponse:
agent_id = uuid7()

# Ensure instructions is a list
Expand All @@ -154,10 +161,5 @@ async def _execute(self, conn, developer_id: UUID, data: CreateAgentRequest) ->
data.metadata,
default_settings,
]
result = await conn.fetchrow(self.query, *params)

return self.wrap_in_class(
[result],
ResourceCreatedResponse,
transform=lambda d: {"id": d["agent_id"], "created_at": d["created_at"]},
)[0]
async with self.pool.acquire() as conn, conn.transaction():
return await conn.fetch(self.query, *params)
44 changes: 0 additions & 44 deletions agents-api/agents_api/queries/base_query.py

This file was deleted.

5 changes: 3 additions & 2 deletions agents-api/agents_api/queries/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ async def _init_conn(conn):
)


async def init_pg_pool(dsn: str):
pool = await asyncpg.create_pool(dsn=dsn, init=_init_conn)
async def init_pg_pool(dsn: str, max_size: int):
pool = await asyncpg.create_pool(dsn=dsn, init=_init_conn, max_size=max_size)
yield pool
pool.close()

Expand All @@ -37,6 +37,7 @@ class Queries(containers.DeclarativeContainer):
db_pool = providers.Resource(
init_pg_pool,
dsn=config.db.dsn,
max_size=config.db.client_pool_max_size,
)
agents: AgentsQueriesContainer = providers.Container(
AgentsQueriesContainer,
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ async def create_agent(
data: CreateAgentRequest,
query: CreateAgentQuery = Depends(Provide[Queries.agents.create])
) -> ResourceCreatedResponse:
# TODO: Validate model name
agent = await query.execute(
developer_id=x_developer_id,
data=data,
Expand Down

0 comments on commit 023b006

Please sign in to comment.