Skip to content

Commit

Permalink
Introduce messages sub router
Browse files Browse the repository at this point in the history
  • Loading branch information
jankrepl committed Jan 15, 2025
1 parent 4ad2193 commit 2d3f7a3
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 32 deletions.
3 changes: 2 additions & 1 deletion src/neuroagent/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_update_kg_hierarchy,
)
from neuroagent.app.middleware import strip_path_prefix
from neuroagent.app.routers import qa, threads, tools
from neuroagent.app.routers import qa, threads, tools, messages

LOGGING = {
"version": 1,
Expand Down Expand Up @@ -145,6 +145,7 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncContextManager[None]: # type:
app.include_router(qa.router)
app.include_router(threads.router)
app.include_router(tools.router)
app.include_router(messages.router)


@app.get("/healthz")
Expand Down
57 changes: 57 additions & 0 deletions src/neuroagent/app/routers/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Message related CRUD operations."""

import json
import logging
from typing import Annotated

from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from neuroagent.app.database.db_utils import get_thread
from neuroagent.app.database.schemas import MessagesRead
from neuroagent.app.database.sql_schemas import Entity, Messages, Threads
from neuroagent.app.dependencies import get_session
from neuroagent.app.routers.threads import router as threads_router

logger = logging.getLogger(__name__)

# Create a messages router
router = APIRouter()


# Define your routes here
@router.get("/")
async def get_thread_messages(
session: Annotated[AsyncSession, Depends(get_session)],
_: Annotated[Threads, Depends(get_thread)], # to check if thread exists
thread_id: str,
) -> list[MessagesRead]:
"""Get all messages of the thread."""
messages_result = await session.execute(
select(Messages)
.where(
Messages.thread_id == thread_id,
Messages.entity.in_([Entity.USER, Entity.AI_MESSAGE]),
)
.order_by(Messages.order)
)
db_messages = messages_result.scalars().all()

messages = []
for msg in db_messages:
messages.append(
MessagesRead(
msg_content=json.loads(msg.content)["content"],
**msg.__dict__,
)
)

return messages


# Include the messages router under threads at the end of the file
threads_router.include_router(
router,
prefix="/{thread_id}/messages",
)
33 changes: 2 additions & 31 deletions src/neuroagent/app/routers/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from neuroagent.app.app_utils import validate_project
from neuroagent.app.config import Settings
from neuroagent.app.database.db_utils import get_thread
from neuroagent.app.database.schemas import MessagesRead, ThreadsRead, ThreadUpdate
from neuroagent.app.database.sql_schemas import Entity, Messages, Threads
from neuroagent.app.database.schemas import ThreadsRead, ThreadUpdate
from neuroagent.app.database.sql_schemas import Threads
from neuroagent.app.dependencies import (
get_httpx_client,
get_kg_token,
Expand Down Expand Up @@ -73,35 +73,6 @@ async def get_threads(
return [ThreadsRead(**thread.__dict__) for thread in threads]


@router.get("/{thread_id}")
async def get_messages(
session: Annotated[AsyncSession, Depends(get_session)],
_: Annotated[Threads, Depends(get_thread)], # to check if thread exist
thread_id: str,
) -> list[MessagesRead]:
"""Get all messages of the thread."""
messages_result = await session.execute(
select(Messages)
.where(
Messages.thread_id == thread_id,
Messages.entity.in_([Entity.USER, Entity.AI_MESSAGE]),
)
.order_by(Messages.order)
)
db_messages = messages_result.scalars().all()

messages = []
for msg in db_messages:
messages.append(
MessagesRead(
msg_content=json.loads(msg.content)["content"],
**msg.__dict__,
)
)

return messages


@router.patch("/{thread_id}")
async def update_thread_title(
session: Annotated[AsyncSession, Depends(get_session)],
Expand Down

0 comments on commit 2d3f7a3

Please sign in to comment.