Skip to content

Commit

Permalink
Added support for default naming of personas
Browse files Browse the repository at this point in the history
  • Loading branch information
kirgrim committed Mar 21, 2024
1 parent b489d20 commit c5a447c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
10 changes: 8 additions & 2 deletions chat_server/server_utils/models/personas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,19 @@ def persona_id(self):


class AddPersonaModel(Persona):
supported_llms: list[str] = Field(examples=[["chatgpt", "llama", "fastchat"]])
supported_llms: list[str] = Field(
examples=[["chat_gpt", "llama", "fastchat"]], default=[]
)
default_llm: str | None = Field(examples=["chat_gpt"], default=None)
description: str = Field(examples=["I am the doctor. I am helping people."])
enabled: bool = False


class SetPersonaModel(Persona):
supported_llms: list[str] = Field(examples=[["chatgpt", "llama", "fastchat"]])
supported_llms: list[str] = Field(
examples=[["chat_gpt", "llama", "fastchat"]], default=[]
)
default_llm: str | None = Field(examples=["chat_gpt"], default=None)
description: str = Field(examples=["I am the doctor. I am helping people."])


Expand Down
47 changes: 34 additions & 13 deletions services/klatchat_observer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import json
import re
import time
import cachetools.func

from threading import Event, Timer

import requests
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
self.server_url = self.sio_url
self._klat_session_token = None
self.klat_auth_credentials = config.get("KLAT_AUTH_CREDENTIALS", {})
self.default_persona_llms = dict()
self.connect_sio()
self.register_consumer(
name="neon_response",
Expand Down Expand Up @@ -206,16 +209,15 @@ def get_recipient_from_prefix(cls, message: str) -> dict:
break
return callback

@staticmethod
def get_recipient_from_body(message: str) -> dict:
def get_recipient_from_body(self, message: str) -> dict:
"""
Gets recipients from message body
:param message: user's message
:returns extracted recipient
Example:
>>> assert ChatObserver.get_recipient_from_body('@Proctor hello dsfdsfsfds @Prompter') == {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': {'proctor', 'prompter'}}
>>> assert self.get_recipient_from_body('@Proctor hello dsfdsfsfds @Prompter') == {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': {'proctor', 'prompter'}}
"""
message = " " + message
bot_mentioning_regexp = r"[\s]+@[a-zA-Z]+[\w]+"
Expand All @@ -225,17 +227,24 @@ def get_recipient_from_body(message: str) -> dict:
recipient = Recipients.CHATBOT_CONTROLLER
else:
recipient = Recipients.UNRESOLVED
return {"recipient": recipient, "context": {"requested_participants": bots}}
return {
"recipient": recipient,
"context": {
"requested_participants": [
self.default_persona_llms.get(bot, bot) for bot in bots
]
},
}

@staticmethod
def get_recipient_from_bound_service(bound_service) -> dict:
def get_recipient_from_bound_service(self, bound_service) -> dict:
"""Gets recipient in case bounded service is received in data"""
response = {}
if bound_service.startswith("chatbots"):
bot = bound_service.split(".")[1].split(",")
response = {
"recipient": Recipients.CHATBOT_CONTROLLER,
"context": {
"requested_participants": bound_service.split(".")[1].split(",")
"requested_participants": self.default_persona_llms.get(bot, bot)
},
}
elif bound_service.startswith("neon"):
Expand Down Expand Up @@ -743,7 +752,12 @@ def on_subminds_state(self, body: dict):

@create_mq_callback()
def on_get_configured_personas(self, body: dict):
response_data = self._fetch_persona_api(body=body)
response_data = self._fetch_persona_api(user_id=body.get("user_id"))
response_data["items"] = [
item
for item in response_data["items"]
if body["service_name"] in item["supported_llms"]
]
response_data.setdefault("context", {}).setdefault("mq", {}).setdefault(
"message_id", body["message_id"]
)
Expand All @@ -754,20 +768,27 @@ def on_get_configured_personas(self, body: dict):
expiration=3000,
)

def _fetch_persona_api(self, body: dict) -> dict:
query_string = self._build_persona_api_query(body=body)
@cachetools.func.ttl_cache(ttl=2 * 60)
def _fetch_persona_api(self, user_id: str) -> dict:
query_string = self._build_persona_api_query(user_id=user_id)
url = f"{self.server_url}/personas/list?{query_string}"
try:
response = self._fetch_klat_server(url=url)
data = response.json()
self._refresh_default_persona_llms(data=data)
except KlatAPIAuthorizationError:
LOG.error(f"Failed to fetch personas from {url = }")
data = {"items": []}
return data

def _build_persona_api_query(self, body: dict) -> str:
url_query_params = f"llms={body['service_name']}&only_enabled=true"
if user_id := body.get("user_id"):
def _refresh_default_persona_llms(self, data):
for item in data["items"]:
if default_llm := item.get("default_llm"):
self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm

def _build_persona_api_query(self, user_id: str) -> str:
url_query_params = f"only_enabled=true"
if user_id:
url_query_params += f"&user_id={user_id}"
return url_query_params

Expand Down

0 comments on commit c5a447c

Please sign in to comment.