diff --git a/chat_server/server_utils/models/personas.py b/chat_server/server_utils/models/personas.py index 2ba721a9..a5d23bcb 100644 --- a/chat_server/server_utils/models/personas.py +++ b/chat_server/server_utils/models/personas.py @@ -48,13 +48,19 @@ def persona_id(self): class AddPersonaModel(Persona): - supported_llms: list[str] = Field(examples=[["chatgpt", "llama", "fastchat"]]) + supported_llms: list[str] = Field( + examples=[["chat_gpt", "llama", "fastchat"]], default=[] + ) + default_llm: str | None = Field(examples=["chat_gpt"], default=None) description: str = Field(examples=["I am the doctor. I am helping people."]) enabled: bool = False class SetPersonaModel(Persona): - supported_llms: list[str] = Field(examples=[["chatgpt", "llama", "fastchat"]]) + supported_llms: list[str] = Field( + examples=[["chat_gpt", "llama", "fastchat"]], default=[] + ) + default_llm: str | None = Field(examples=["chat_gpt"], default=None) description: str = Field(examples=["I am the doctor. I am helping people."]) diff --git a/services/klatchat_observer/controller.py b/services/klatchat_observer/controller.py index 3c058b1b..2e0305cd 100644 --- a/services/klatchat_observer/controller.py +++ b/services/klatchat_observer/controller.py @@ -28,6 +28,8 @@ import json import re import time +import cachetools.func + from threading import Event, Timer import requests @@ -96,6 +98,7 @@ def __init__( self.server_url = self.sio_url self._klat_session_token = None self.klat_auth_credentials = config.get("KLAT_AUTH_CREDENTIALS", {}) + self.default_persona_llms = dict() self.connect_sio() self.register_consumer( name="neon_response", @@ -206,8 +209,7 @@ def get_recipient_from_prefix(cls, message: str) -> dict: break return callback - @staticmethod - def get_recipient_from_body(message: str) -> dict: + def get_recipient_from_body(self, message: str) -> dict: """ Gets recipients from message body @@ -215,7 +217,7 @@ def get_recipient_from_body(message: str) -> dict: :returns extracted recipient Example: - >>> assert ChatObserver.get_recipient_from_body('@Proctor hello dsfdsfsfds @Prompter') == {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': {'proctor', 'prompter'}} + >>> assert self.get_recipient_from_body('@Proctor hello dsfdsfsfds @Prompter') == {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': {'proctor', 'prompter'}} """ message = " " + message bot_mentioning_regexp = r"[\s]+@[a-zA-Z]+[\w]+" @@ -225,17 +227,24 @@ def get_recipient_from_body(message: str) -> dict: recipient = Recipients.CHATBOT_CONTROLLER else: recipient = Recipients.UNRESOLVED - return {"recipient": recipient, "context": {"requested_participants": bots}} + return { + "recipient": recipient, + "context": { + "requested_participants": [ + self.default_persona_llms.get(bot, bot) for bot in bots + ] + }, + } - @staticmethod - def get_recipient_from_bound_service(bound_service) -> dict: + def get_recipient_from_bound_service(self, bound_service) -> dict: """Gets recipient in case bounded service is received in data""" response = {} if bound_service.startswith("chatbots"): + bot = bound_service.split(".")[1].split(",") response = { "recipient": Recipients.CHATBOT_CONTROLLER, "context": { - "requested_participants": bound_service.split(".")[1].split(",") + "requested_participants": self.default_persona_llms.get(bot, bot) }, } elif bound_service.startswith("neon"): @@ -743,7 +752,12 @@ def on_subminds_state(self, body: dict): @create_mq_callback() def on_get_configured_personas(self, body: dict): - response_data = self._fetch_persona_api(body=body) + response_data = self._fetch_persona_api(user_id=body.get("user_id")) + response_data["items"] = [ + item + for item in response_data["items"] + if body["service_name"] in item["supported_llms"] + ] response_data.setdefault("context", {}).setdefault("mq", {}).setdefault( "message_id", body["message_id"] ) @@ -754,20 +768,27 @@ def on_get_configured_personas(self, body: dict): expiration=3000, ) - def _fetch_persona_api(self, body: dict) -> dict: - query_string = self._build_persona_api_query(body=body) + @cachetools.func.ttl_cache(ttl=2 * 60) + def _fetch_persona_api(self, user_id: str) -> dict: + query_string = self._build_persona_api_query(user_id=user_id) url = f"{self.server_url}/personas/list?{query_string}" try: response = self._fetch_klat_server(url=url) data = response.json() + self._refresh_default_persona_llms(data=data) except KlatAPIAuthorizationError: LOG.error(f"Failed to fetch personas from {url = }") data = {"items": []} return data - def _build_persona_api_query(self, body: dict) -> str: - url_query_params = f"llms={body['service_name']}&only_enabled=true" - if user_id := body.get("user_id"): + def _refresh_default_persona_llms(self, data): + for item in data["items"]: + if default_llm := item.get("default_llm"): + self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm + + def _build_persona_api_query(self, user_id: str) -> str: + url_query_params = f"only_enabled=true" + if user_id: url_query_params += f"&user_id={user_id}" return url_query_params