Skip to content

Commit

Permalink
Add locking around configured_personas_changed to ensure timestamps…
Browse files Browse the repository at this point in the history
… are in the same order as persona responses

Includes `update_time` in TTL cached query so cached responses include accurate timestamps
  • Loading branch information
NeonDaniel committed Dec 3, 2024
1 parent f1cb217 commit 7887540
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
42 changes: 25 additions & 17 deletions chat_server/server_utils/socketio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from time import time

from typing import Optional, List
from asyncio import Lock

from chat_server.server_utils.api_dependencies import (CurrentUserModel,
ListPersonasQueryModel)
from chat_server.sio.server import sio


_LOCK = Lock()


async def notify_personas_changed(supported_llms: Optional[List[str]] = None):
"""
Emit an SIO event for each LLM affected by a persona change. This sends a
Expand All @@ -44,20 +49,23 @@ async def notify_personas_changed(supported_llms: Optional[List[str]] = None):
then updates all LLMs listed in database configuration
"""
from chat_server.blueprints.personas import list_personas
resp = await list_personas(CurrentUserModel(_id="", nickname="",
first_name="", last_name=""),
ListPersonasQueryModel(only_enabled=True))
enabled_personas = json.loads(resp.body.decode())
valid_personas = {}
if supported_llms:
# Only broadcast updates for LLMs affected by an insert/change request
for llm in supported_llms:
valid_personas[llm] = [per for per in enabled_personas["items"] if
llm in per["supported_llms"]]
else:
# Delete request does not have LLM context, update everything
for persona in enabled_personas["items"]:
for llm in persona["supported_llms"]:
valid_personas.setdefault(llm, [])
valid_personas[llm].append(persona)
sio.emit("configured_personas_changed", {"personas": valid_personas})
async with _LOCK:
resp = await list_personas(CurrentUserModel(_id="", nickname="",
first_name="", last_name=""),
ListPersonasQueryModel(only_enabled=True))
update_time = time()
enabled_personas = json.loads(resp.body.decode())
valid_personas = {}
if supported_llms:
# Only broadcast updates for LLMs affected by an insert/change request
for llm in supported_llms:
valid_personas[llm] = [per for per in enabled_personas["items"] if
llm in per["supported_llms"]]
else:
# Delete request does not have LLM context, update everything
for persona in enabled_personas["items"]:
for llm in persona["supported_llms"]:
valid_personas.setdefault(llm, [])
valid_personas[llm].append(persona)
sio.emit("configured_personas_changed", {"personas": valid_personas,
"update_time": update_time})
5 changes: 4 additions & 1 deletion services/klatchat_observer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ def _fetch_persona_api(self, user_id: Optional[str]) -> dict:
try:
response = self._fetch_klat_server(url=url)
data = response.json()
data['update_time'] = time.time()
self._refresh_default_persona_llms(data=data)
except KlatAPIAuthorizationError:
LOG.error(f"Failed to fetch personas from {url = }")
Expand All @@ -821,7 +822,9 @@ def _handle_personas_changed(self, data: dict):
"""
for llm, personas in data["personas"].items():
self.send_message(
request_data={"items": personas},
request_data={
"items": personas,
"update_time": data.get("update_time") or time.time()},
vhost=self.get_vhost("llm"),
queue=f"{llm}_personas_input",
expiration=5000,
Expand Down

0 comments on commit 7887540

Please sign in to comment.