diff --git a/chat_server/app.py b/chat_server/app.py index 1b6bc292..2f290779 100644 --- a/chat_server/app.py +++ b/chat_server/app.py @@ -46,6 +46,7 @@ from .sio import sio from .blueprints import ( + admin as admin_blueprint, auth as auth_blueprint, chat as chat_blueprint, users as users_blueprint, @@ -95,6 +96,7 @@ async def log_requests(request: Request, call_next): LOG.error(f"rid={idem} received an exception {ex}") return None + chat_app.include_router(admin_blueprint.router) chat_app.include_router(auth_blueprint.router) chat_app.include_router(chat_blueprint.router) chat_app.include_router(users_blueprint.router) diff --git a/chat_server/blueprints/admin.py b/chat_server/blueprints/admin.py new file mode 100644 index 00000000..1fc394f8 --- /dev/null +++ b/chat_server/blueprints/admin.py @@ -0,0 +1,74 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from fastapi import APIRouter +from starlette.requests import Request + +from utils.logging_utils import LOG +from utils.http_utils import respond + +from chat_server.server_config import k8s_config +from chat_server.server_utils.auth import login_required +from chat_server.server_utils.k8s_utils import restart_deployment +from chat_server.server_utils.admin_utils import run_mq_validation + +router = APIRouter( + prefix="/admin", + responses={"404": {"description": "Unknown authorization endpoint"}}, +) + + +@router.post("/refresh/{service_name}") +@login_required(tmp_allowed=False, required_roles=["admin"]) +async def refresh_state( + request: Request, service_name: str, target_items: str | None = "" +): + """ + Refreshes state of the target + + :param request: Starlette Request Object + :param service_name: name of service to refresh + :param target_items: comma-separated list of items to refresh + + :returns JSON-formatted response from server + """ + target_items = [x for x in target_items.split(",") if x] + if service_name == "k8s": + if not k8s_config: + return respond("K8S Service Unavailable", 503) + deployments = target_items + if deployments == "*": + deployments = k8s_config.get("MANAGED_DEPLOYMENTS", []) + LOG.info(f"Restarting {deployments=!r}") + for deployment in deployments: + restart_deployment(deployment_name=deployment) + elif service_name == "mq": + run_mq_validation() + else: + return respond(f"Unknown refresh type: {service_name!r}", 404) + return respond("OK") diff --git a/chat_server/server_config.py b/chat_server/server_config.py index 5c3bea0b..ca03e2c0 100644 --- a/chat_server/server_config.py +++ b/chat_server/server_config.py @@ -29,8 +29,12 @@ import os from typing import Optional +from utils.logging_utils import LOG + from config import Configuration from chat_server.server_utils.sftp_utils import init_sftp_connector +from chat_server.server_utils.rmq_utils import RabbitMQAPI + from utils.logging_utils import LOG from utils.database_utils import DatabaseController @@ -86,3 +90,14 @@ def _init_db_controller(db_config: dict) -> Optional[DatabaseController]: LOG.info(f"App config: {app_config}") sftp_connector = init_sftp_connector(config=app_config.get("SFTP", {})) + +mq_api = None +mq_management_config = config.get("MQ_MANAGEMENT", {}) +if mq_management_url := mq_management_config.get("MQ_MANAGEMENT_URL"): + mq_api = RabbitMQAPI(url=mq_management_url) + mq_api.login( + username=mq_management_config["MQ_MANAGEMENT_LOGIN"], + password=mq_management_config["MQ_MANAGEMENT_PASSWORD"], + ) + +k8s_config = config.get("K8S_CONFIG", {}) diff --git a/chat_server/server_utils/admin_utils.py b/chat_server/server_utils/admin_utils.py new file mode 100644 index 00000000..937d00b9 --- /dev/null +++ b/chat_server/server_utils/admin_utils.py @@ -0,0 +1,52 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from chat_server.server_config import mq_api, mq_management_config, LOG + + +def run_mq_validation(): + if mq_api: + for vhost in mq_management_config.get("VHOSTS", []): + status = mq_api.add_vhost(vhost=vhost["name"]) + if not status.ok: + raise ConnectionError(f'Failed to add {vhost["name"]}, {status=}') + for user_creds in mq_management_config.get("USERS", []): + mq_api.add_user( + user=user_creds["name"], + password=user_creds["password"], + tags=user_creds.get("tags", ""), + ) + for user_vhost_permissions in mq_management_config.get( + "USER_VHOST_PERMISSIONS", [] + ): + mq_api.configure_vhost_user_permissions(**user_vhost_permissions) + else: + LOG.error("MQ API is unavailable") + + +if __name__ == "__main__": + run_mq_validation() diff --git a/chat_server/server_utils/auth.py b/chat_server/server_utils/auth.py index 4d21ef27..ce059d72 100644 --- a/chat_server/server_utils/auth.py +++ b/chat_server/server_utils/auth.py @@ -258,7 +258,10 @@ def refresh_session(payload: dict): def validate_session( - request: Union[str, Request], check_tmp: bool = False, sio_request: bool = False + request: Union[str, Request], + check_tmp: bool = False, + required_roles: list = None, + sio_request: bool = False, ) -> Tuple[str, int]: """ Check if session token contained in request is valid @@ -269,12 +272,20 @@ def validate_session( payload = jwt.decode( jwt=session, key=secret_key, algorithms=jwt_encryption_algo ) - if check_tmp: + should_check_user_data = check_tmp or required_roles + is_authorized = True + if should_check_user_data: from chat_server.server_utils.db_utils import DbUtils user = DbUtils.get_user(user_id=payload["sub"]) - if user.get("is_tmp"): - return "Permission denied", 403 + if check_tmp and user.get("is_tmp"): + is_authorized = False + elif required_roles and not any( + user_role in required_roles for user_role in user.get("roles", []) + ): + is_authorized = False + if not is_authorized: + return "Permission denied", 403 if (int(time()) - int(payload.get("creation_time", 0))) <= session_lifetime: return "OK", 200 return "Session Expired", 401 @@ -298,7 +309,9 @@ def outer(func): @wraps(func) async def wrapper(request: Request, *args, **kwargs): session_validation_output = validate_session( - request, check_tmp=not outer_kwargs.get("tmp_allowed") + request, + check_tmp=not outer_kwargs.get("tmp_allowed"), + required_roles=outer_kwargs.get("required_roles"), ) LOG.debug( f"(url={request.url}) Received session validation output: {session_validation_output}" diff --git a/chat_server/server_utils/cache_utils.py b/chat_server/server_utils/cache_utils.py index 75121adc..7acff20d 100644 --- a/chat_server/server_utils/cache_utils.py +++ b/chat_server/server_utils/cache_utils.py @@ -45,6 +45,7 @@ def get(cls, name: str, cache_type: Type = None, **kwargs): """ if not cls.__active_caches.get(name): if cache_type: + kwargs.setdefault("maxsize", 124) cls.__active_caches[name] = cache_type(**kwargs) else: raise KeyError(f"Missing cache instance under {name}") diff --git a/chat_server/server_utils/db_utils.py b/chat_server/server_utils/db_utils.py index f2bbc99d..45e29e9a 100644 --- a/chat_server/server_utils/db_utils.py +++ b/chat_server/server_utils/db_utils.py @@ -493,7 +493,7 @@ def save_translations(cls, translation_mapping: dict) -> Dict[str, List[str]]: filter_expression = {"_id": shout_id} cls.db_controller.exec_query( query=MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.SHOUTS, filters=filter_expression, data={"translations": {}}, @@ -546,7 +546,7 @@ def get_user_preferences(cls, user_id): if user and not user.get("preferences"): cls.db_controller.exec_query( MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.USERS, filters=MongoFilter(key="_id", value=user_id), data={"preferences": prefs}, @@ -570,7 +570,7 @@ def set_user_preferences(cls, user_id, preferences_mapping: dict): } cls.db_controller.exec_query( MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.USERS, filters=MongoFilter("_id", user_id), data=update_mapping, @@ -603,7 +603,7 @@ def save_tts_response( ) cls.db_controller.exec_query( query=MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.SHOUTS, filters=MongoFilter("_id", shout_id), data={f"audio.{lang}.{gender}": audio_file_name}, @@ -628,7 +628,7 @@ def save_stt_response(cls, shout_id, message_text: str, lang: str = "en"): try: cls.db_controller.exec_query( query=MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.SHOUTS, filters=MongoFilter("_id", shout_id), data={f"transcripts.{lang}": message_text}, diff --git a/chat_server/server_utils/k8s_utils.py b/chat_server/server_utils/k8s_utils.py new file mode 100644 index 00000000..2302b51f --- /dev/null +++ b/chat_server/server_utils/k8s_utils.py @@ -0,0 +1,80 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import datetime +import os + +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +from chat_server.server_config import k8s_config +from utils.logging_utils import LOG + +k8s_app_api = None +_k8s_default_namespace = "default" + +if _k8s_config_path := k8s_config.get("K8S_CONFIG_PATH"): + _k8s_default_namespace = ( + k8s_config.get("K8S_DEFAULT_NAMESPACE") or _k8s_default_namespace + ) + config.load_kube_config(_k8s_config_path) + + k8s_app_api = client.AppsV1Api() +else: + LOG.warning("K8S config is unset!") + + +def restart_deployment(deployment_name: str, namespace: str = _k8s_default_namespace): + """ + Restarts K8S deployment + :param deployment_name: name of the deployment to restart + :param namespace: name of the namespace + """ + if not k8s_app_api: + LOG.error( + f"Failed to restart {deployment_name=!r} ({namespace=!r}) - missing K8S configs" + ) + return -1 + now = datetime.datetime.utcnow() + now = str(now.isoformat() + "Z") + body = { + "spec": { + "template": { + "metadata": {"annotations": {"kubectl.kubernetes.io/restartedAt": now}} + } + } + } + try: + k8s_app_api.patch_namespaced_deployment( + deployment_name, namespace, body, pretty="true" + ) + except ApiException as e: + LOG.error( + "Exception when calling AppsV1Api->read_namespaced_deployment_status: %s\n" + % e + ) diff --git a/chat_server/server_utils/prompt_utils.py b/chat_server/server_utils/prompt_utils.py index 1f5c946b..2bea8587 100644 --- a/chat_server/server_utils/prompt_utils.py +++ b/chat_server/server_utils/prompt_utils.py @@ -81,7 +81,7 @@ def handle_prompt_message(message: dict) -> bool: } db_controller.exec_query( MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.PROMPTS, filters=MongoFilter(key="_id", value=prompt_id), **data_kwargs, @@ -128,7 +128,7 @@ def handle_prompt_message(message: dict) -> bool: } db_controller.exec_query( MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.PROMPTS, filters=MongoFilter(key="_id", value=prompt_id), **data_kwargs, diff --git a/chat_server/server_utils/rmq_utils.py b/chat_server/server_utils/rmq_utils.py new file mode 100644 index 00000000..8dc88b1f --- /dev/null +++ b/chat_server/server_utils/rmq_utils.py @@ -0,0 +1,174 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import requests + +from urllib.parse import quote_plus +from requests.auth import HTTPBasicAuth + + +# TODO: use this package from neon_diana_utils once its dependencies won't cause conficts +class RabbitMQAPI: + def __init__(self, url: str, verify_ssl: bool = False): + """ + Creates an object used to interface with a RabbitMQ server + :param url: Management URL (usually IP:port) + """ + self._verify_ssl = verify_ssl + self.console_url = url + self._username = None + self._password = None + + def login(self, username: str, password: str): + """ + Sets internal username/password parameters used to generate HTTP auth + :param username: user to authenticate as + :param password: plaintext password to authenticate with + """ + self._username = username + self._password = password + # TODO: Check auth and return DM + + @property + def auth(self): + """ + HTTPBasicAuth object to include with requests. + """ + return HTTPBasicAuth(self._username, self._password) + + def add_vhost(self, vhost: str) -> bool: + """ + Add a vhost to the server + :param vhost: vhost to add + :return: True if request was successful + """ + status = requests.put( + f"{self.console_url}/api/vhosts/{quote_plus(vhost)}", + auth=self.auth, + verify=self._verify_ssl, + ) + return status + + def add_user(self, user: str, password: str, tags: str = "") -> bool: + """ + Add a user to the server + :param user: username to add + :param password: password for user + :param tags: comma-delimited list of tags to assign to new user + :return: True if request was successful + """ + tags = tags or "" + body = {"password": password, "tags": tags} + status = requests.put( + f"{self.console_url}/api/users/{quote_plus(user)}", + data=json.dumps(body), + auth=self.auth, + verify=self._verify_ssl, + ) + return status.ok + + def delete_user(self, user: str) -> bool: + """ + Delete a user from the server + :param user: username to remove + """ + status = requests.delete( + f"{self.console_url}/api/users/{quote_plus(user)}", + auth=self.auth, + verify=self._verify_ssl, + ) + return status.ok + + def configure_vhost_user_permissions( + self, + vhost: str, + user: str, + configure: str = ".*", + write: str = ".*", + read: str = ".*", + ) -> bool: + """ + Configure user's access to vhost. See RabbitMQ docs: + https://www.rabbitmq.com/access-control.html#authorisation + :param vhost: vhost to set/modify permissions for + :param user: user to set/modify permissions of + :param configure: regex configure permissions + :param write: regex write permissions + :param read: regex read permissions + :return: True if request was successful + """ + url = ( + f"{self.console_url}/api/permissions/{quote_plus(vhost)}/" + f"{quote_plus(user)}" + ) + body = {"configure": configure, "write": write, "read": read} + status = requests.put( + url, data=json.dumps(body), auth=self.auth, verify=self._verify_ssl + ) + return status.ok + + def get_definitions(self): + """ + Get the server definitions for RabbitMQ; these are used to persist + configuration between container restarts + """ + resp = requests.get( + f"{self.console_url}/api/definitions", + auth=self.auth, + verify=self._verify_ssl, + ) + data = json.loads(resp.content) + return data + + def create_default_users(self, users: list) -> dict: + """ + Creates the passed list of users with random passwords and returns a + dict of users to passwords + :param users: list of usernames to create + :return: Dict of created usernames and associated passwords + """ + import secrets + + credentials = dict() + for user in users: + passwd = secrets.token_urlsafe(32) + credentials[user] = passwd + self.add_user(user, passwd) + return credentials + + def configure_admin_account(self, username: str, password: str) -> bool: + """ + Configures an administrator with the passed credentials and removes + the default account + :param username: New administrator's username + :param password: New administrator's password + :return: True if action was successful + """ + create = self.add_user(username, password, "administrator") + self.login(username, password) + delete = self.delete_user("guest") + return create and delete diff --git a/chat_server/server_utils/user_utils.py b/chat_server/server_utils/user_utils.py index f14f0d42..b3fef4c3 100644 --- a/chat_server/server_utils/user_utils.py +++ b/chat_server/server_utils/user_utils.py @@ -135,7 +135,7 @@ def get_bot_data( elif not bot_data.get("is_bot") == "1": db_controller.exec_query( MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.USERS, filters=MongoFilter("_id", bot_data["_id"]), data={"is_bot": "1"}, diff --git a/chat_server/sio.py b/chat_server/sio.py index daf21a54..c8cd17bb 100644 --- a/chat_server/sio.py +++ b/chat_server/sio.py @@ -261,7 +261,7 @@ async def user_message(sid, data): ) db_controller.exec_query( query=MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.CHATS, filters=filter_expression, data={"chat_flow": new_shout_data["_id"]}, @@ -363,7 +363,7 @@ async def prompt_completed(sid, data): try: db_controller.exec_query( MongoQuery( - command=MongoCommands.UPDATE, + command=MongoCommands.UPDATE_MANY, document=MongoDocuments.PROMPTS, filters=MongoFilter(key="_id", value=prompt_id), data=prompt_summary_agg, @@ -458,9 +458,9 @@ async def request_translate(sid, data): "sid": sid, "input_type": input_type, } - CacheFactory.get("translation_cache", cache_type=LRUCache).put( - key=request_id, value=caching_instance - ) + CacheFactory.get("translation_cache", cache_type=LRUCache)[ + request_id + ] = caching_instance await sio.emit( "request_neon_translations", data={"request_id": request_id, "data": missing_translations}, @@ -731,6 +731,23 @@ async def request_stt(sid, data): await sio.emit("get_stt", data=formatted_data) +@sio.event +# @login_required +async def broadcast(sid, data): + """Forwards received broadcast message from client""" + # TODO: introduce certification mechanism to forward messages only from trusted entities + msg_type = data.pop("msg_type", None) + msg_receivers = data.pop("to", None) + if not msg_type: + LOG.error(f'data={data} skipped - no "msg_type" provided') + if msg_type: + await sio.emit( + msg_type, + data=data, + to=msg_receivers, + ) + + async def emit_error( message: str, context: Optional[dict] = None, sids: Optional[List[str]] = None ): diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0b1cf71f..addf431a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -6,6 +6,7 @@ fastapi-socketio==0.0.10 httpx==0.25.0 # required by FastAPI Jinja2==3.1.2 jsbeautifier==1.14.7 +kubernetes==28.1.0 neon-mq-connector~=0.7 neon-sftp~=0.1 ovos_config==0.0.10 diff --git a/services/klatchat_observer/controller.py b/services/klatchat_observer/controller.py index eed0ba62..8129e0b3 100644 --- a/services/klatchat_observer/controller.py +++ b/services/klatchat_observer/controller.py @@ -34,6 +34,7 @@ from enum import Enum +from neon_mq_connector.utils import retry from neon_mq_connector.utils.rabbit_utils import create_mq_callback from neon_mq_connector.connector import MQConnector from utils.logging_utils import LOG @@ -87,13 +88,7 @@ def __init__( } self._sio = None self.sio_url = config["SIO_URL"] - try: - self.connect_sio() - except Exception as ex: - err = f"Failed to connect Socket IO at {self.sio_url} due to exception={str(ex)}, observing will not be run" - LOG.warning(err) - if not self.testing_mode: - raise ConnectionError(err) + self.connect_sio() self.register_consumer( name="neon_response", vhost=self.get_vhost("neon_api"), @@ -157,6 +152,13 @@ def __init__( callback=self.on_neon_translations_response, on_error=self.default_error_handler, ) + self.register_subscriber( + name="subminds_state_receiver", + vhost=self.get_vhost("chatbots"), + exchange="subminds_state", + callback=self.on_subminds_state, + on_error=self.default_error_handler, + ) @classmethod def get_recipient_from_prefix(cls, message: str) -> dict: @@ -293,16 +295,14 @@ def register_sio_handlers(self): "request_neon_translations", handler=self.request_neon_translations ) - def connect_sio(self, refresh=False): + @retry(use_self=True) + def connect_sio(self): """ Method for establishing connection with Socket IO server - - :param refresh: To refresh an existing instance """ - if not self._sio or refresh: - self._sio = socketio.Client() - self._sio.connect(url=self.sio_url) - self.register_sio_handlers() + self._sio = socketio.Client() + self._sio.connect(url=self.sio_url) + self.register_sio_handlers() @property def sio(self): @@ -710,3 +710,10 @@ def on_tts_response(self, body: dict): """Handles receiving TTS response""" LOG.info(f"Received TTS Response: {body}") self.sio.emit("tts_response", data=body) + + @create_mq_callback() + def on_subminds_state(self, body: dict): + """Handles receiving subminds state message""" + LOG.info(f"Received submind state: {body}") + body["msg_type"] = "subminds_state" + self.sio.emit("broadcast", data=body) diff --git a/utils/database_utils/mongo_utils/structures.py b/utils/database_utils/mongo_utils/structures.py index 043c8e7f..9b5b2338 100644 --- a/utils/database_utils/mongo_utils/structures.py +++ b/utils/database_utils/mongo_utils/structures.py @@ -52,6 +52,8 @@ class MongoCommands(Enum): DELETE_MANY = "delete_many" # Update operation UPDATE = "update_many" + UPDATE_MANY = "update_many" + UPDATE_ONE = "update_one" class MongoDocuments(Enum): @@ -135,7 +137,10 @@ def build_filters(self): def build_setter(self) -> dict: """Builds setter for Mongo Query""" res = None - if self.command.value == MongoCommands.UPDATE.value: + if self.command.value in ( + MongoCommands.UPDATE_MANY.value, + MongoCommands.UPDATE_ONE.value, + ): res = {f"${self.data_action.lower()}": self.data} elif self.command.value in ( MongoCommands.INSERT_ONE.value, diff --git a/utils/database_utils/mongodb_connector.py b/utils/database_utils/mongodb_connector.py index 27483a74..d40cd5e8 100644 --- a/utils/database_utils/mongodb_connector.py +++ b/utils/database_utils/mongodb_connector.py @@ -36,6 +36,7 @@ class MongoDBConnector(DatabaseConnector): """Connector implementing interface for interaction with Mongo DB API""" + mongo_recognised_commands = set(cmd.value for cmd in MongoCommands) @property diff --git a/version.py b/version.py index 4bd3e03c..37c71177 100644 --- a/version.py +++ b/version.py @@ -26,4 +26,4 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -__version__ = "0.4.2a1" +__version__ = "0.4.4"