From 70ac7d47053f2e25f1b3b8cd554d338f417d2d3a Mon Sep 17 00:00:00 2001 From: NeonKirill Date: Mon, 29 Jan 2024 20:23:35 +0100 Subject: [PATCH 1/5] Refactored Mongo DB API to be more granular, optimized and generified queries in place --- chat_server/blueprints/admin.py | 6 +- chat_server/blueprints/auth.py | 22 +- chat_server/blueprints/chat.py | 33 +- chat_server/blueprints/files_api.py | 24 +- chat_server/blueprints/preferences.py | 5 +- chat_server/blueprints/users.py | 13 +- chat_server/server_config.py | 33 +- chat_server/server_utils/auth.py | 64 +- chat_server/server_utils/db_utils.py | 680 ------------------ chat_server/server_utils/enums.py | 2 +- chat_server/server_utils/prompt_utils.py | 142 ---- chat_server/server_utils/user_utils.py | 84 --- chat_server/services/popularity_counter.py | 27 +- chat_server/sio.py | 178 ++--- chat_server/tests/test_sio.py | 3 +- config.py | 3 - migration_scripts/shouts.py | 8 +- utils/__init__.py | 27 + .../mongo_utils/queries/__init__.py | 27 + .../mongo_utils/queries/constants.py | 26 + .../mongo_utils/queries/dao/__init__.py | 27 + .../mongo_utils/queries/dao/abc.py | 176 +++++ .../mongo_utils/queries/dao/chats.py | 128 ++++ .../mongo_utils/queries/dao/prompts.py | 191 +++++ .../mongo_utils/queries/dao/shouts.py | 197 +++++ .../mongo_utils/queries/dao/users.py | 217 ++++++ .../mongo_utils/queries/mongo_queries.py | 219 ++++++ .../mongo_utils/queries/wrapper.py | 68 ++ 28 files changed, 1450 insertions(+), 1180 deletions(-) delete mode 100644 chat_server/server_utils/db_utils.py delete mode 100644 chat_server/server_utils/prompt_utils.py create mode 100644 utils/__init__.py create mode 100644 utils/database_utils/mongo_utils/queries/__init__.py create mode 100644 utils/database_utils/mongo_utils/queries/constants.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/__init__.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/abc.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/chats.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/prompts.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/shouts.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/users.py create mode 100644 utils/database_utils/mongo_utils/queries/mongo_queries.py create mode 100644 utils/database_utils/mongo_utils/queries/wrapper.py diff --git a/chat_server/blueprints/admin.py b/chat_server/blueprints/admin.py index b0de348d..63dcb0a2 100644 --- a/chat_server/blueprints/admin.py +++ b/chat_server/blueprints/admin.py @@ -30,11 +30,11 @@ from starlette.requests import Request from starlette.responses import JSONResponse -from chat_server.server_utils.db_utils import DbUtils +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG from utils.http_utils import respond -from chat_server.server_config import k8s_config, db_controller +from chat_server.server_config import k8s_config from chat_server.server_utils.auth import login_required from chat_server.server_utils.k8s_utils import restart_deployment from chat_server.server_utils.admin_utils import run_mq_validation @@ -79,7 +79,7 @@ async def refresh_state( @router.get("/chats/list") @login_required(tmp_allowed=False, required_roles=["admin"]) async def chats_overview(request: Request, search_str: str = ""): - conversations_data = DbUtils.get_conversation_data( + conversations_data = MongoDocumentsAPI.CHATS.get_conversation_data( search_str=search_str, limit=100, allow_regex_search=True, diff --git a/chat_server/blueprints/auth.py b/chat_server/blueprints/auth.py index ed5ca8c5..f86da171 100644 --- a/chat_server/blueprints/auth.py +++ b/chat_server/blueprints/auth.py @@ -31,13 +31,13 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import JSONResponse -from chat_server.server_config import db_controller from utils.common import get_hash, generate_uuid from chat_server.server_utils.auth import ( check_password_strength, get_current_user_data, generate_session_token, ) +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond router = APIRouter( @@ -64,13 +64,7 @@ async def signup( :returns JSON response with status corresponding to the new user creation status, sets session cookies if creation is successful """ - existing_user = db_controller.exec_query( - query={ - "command": "find_one", - "document": "users", - "data": {"nickname": nickname}, - } - ) + existing_user = MongoDocumentsAPI.USERS.get_user(nickname=nickname) if existing_user: return respond("Nickname is already in use", 400) password_check = check_password_strength(password) @@ -85,9 +79,7 @@ async def signup( date_created=int(time()), is_tmp=False, ) - db_controller.exec_query( - query=dict(document="users", command="insert_one", data=new_user_record) - ) + MongoDocumentsAPI.USERS.add_item(data=new_user_record) token = generate_session_token(user_id=new_user_record["_id"]) @@ -104,13 +96,7 @@ async def login(username: str = Form(...), password: str = Form(...)): :returns JSON response with status corresponding to authorization status, sets session cookie with response """ - user = db_controller.exec_query( - query={ - "command": "find_one", - "document": "users", - "data": {"nickname": username}, - } - ) + user = MongoDocumentsAPI.USERS.get_user(nickname=username) if not user or user.get("is_tmp", False): return respond("Invalid username or password", 400) db_password = user["password"] diff --git a/chat_server/blueprints/chat.py b/chat_server/blueprints/chat.py index 774f11c9..2dffc683 100644 --- a/chat_server/blueprints/chat.py +++ b/chat_server/blueprints/chat.py @@ -32,17 +32,12 @@ from fastapi.responses import JSONResponse from chat_server.constants.conversations import ConversationSkins -from chat_server.server_config import db_controller from chat_server.server_utils.auth import login_required from chat_server.server_utils.conversation_utils import build_message_json -from chat_server.server_utils.db_utils import ( - DbUtils, - MongoQuery, - MongoCommands, - MongoDocuments, -) from chat_server.services.popularity_counter import PopularityCounter from utils.common import generate_uuid +from utils.database_utils.mongo_utils.queries.mongo_queries import fetch_message_data +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG @@ -56,7 +51,7 @@ @login_required async def new_conversation( request: Request, - conversation_id: str = Form(""), + conversation_id: str = Form(""), # DEPRECATED conversation_name: str = Form(...), is_private: str = Form(False), bound_service: str = Form(""), @@ -65,7 +60,7 @@ async def new_conversation( Creates new conversation from provided conversation data :param request: Starlette Request object - :param conversation_id: new conversation id (optional) + :param conversation_id: new conversation id (DEPRECATED) :param conversation_name: new conversation name (optional) :param is_private: if new conversation should be private (defaults to False) :param bound_service: name of the bound service (ignored if empty value) @@ -73,12 +68,12 @@ async def new_conversation( :returns JSON response with new conversation data if added, 401 error message otherwise """ - conversation_data = DbUtils.get_conversation_data( - search_str=[conversation_id, conversation_name] + conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=[conversation_id, conversation_name], ) if conversation_data: return respond(f'Conversation "{conversation_name}" already exists', 400) - cid = conversation_id or generate_uuid() + cid = generate_uuid() request_data_dict = { "_id": cid, "conversation_name": conversation_name, @@ -86,13 +81,7 @@ async def new_conversation( "bound_service": bound_service, "created_on": int(time()), } - db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.INSERT_ONE, - document=MongoDocuments.CHATS, - data=request_data_dict, - ) - ) + MongoDocumentsAPI.CHATS.add_item(data=request_data_dict) PopularityCounter.add_new_chat(cid=cid, name=conversation_name) return JSONResponse(content=request_data_dict) @@ -119,13 +108,15 @@ async def get_matching_conversation( :returns conversation data if found, 401 error code otherwise """ - conversation_data = DbUtils.get_conversation_data(search_str=search_str) + conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=search_str + ) if not conversation_data: return respond(f'No conversation matching = "{search_str}"', 404) message_data = ( - DbUtils.fetch_skin_message_data( + fetch_message_data( skin=skin, conversation_data=conversation_data, start_idx=chat_history_from, diff --git a/chat_server/blueprints/files_api.py b/chat_server/blueprints/files_api.py index 3e4f9a99..47be5778 100644 --- a/chat_server/blueprints/files_api.py +++ b/chat_server/blueprints/files_api.py @@ -31,10 +31,9 @@ from starlette.requests import Request from starlette.responses import JSONResponse -from chat_server.server_config import db_controller from chat_server.server_utils.auth import login_required -from chat_server.server_utils.db_utils import DbUtils from chat_server.server_utils.http_utils import get_file_response, save_file +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG @@ -50,11 +49,11 @@ async def get_audio_message( message_id: str, ): """Gets file based on the name""" - matching_shouts = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if matching_shouts and matching_shouts[0].get("is_audio", "0") == "1": + matching_shout = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id) + if matching_shout and matching_shout.get("is_audio", "0") == "1": LOG.info(f"Fetching audio for message_id={message_id}") return get_file_response( - matching_shouts[0]["message_text"], + matching_shout["message_text"], location_prefix="audio", media_type="audio/wav", ) @@ -70,12 +69,7 @@ async def get_avatar(user_id: str): :param user_id: target user id """ LOG.debug(f"Getting avatar of user id: {user_id}") - user_data = ( - db_controller.exec_query( - query={"document": "users", "command": "find_one", "data": {"_id": user_id}} - ) - or {} - ) + user_data = MongoDocumentsAPI.USERS.get_user(user_id=user_id) or {} if user_data.get("avatar", None): num_attempts = 0 try: @@ -101,13 +95,11 @@ async def get_message_attachment(request: Request, msg_id: str, filename: str): :param filename: name of the file to get """ LOG.debug(f"{msg_id} - {filename}") - message_files = db_controller.exec_query( - query={"document": "shouts", "command": "find_one", "data": {"_id": msg_id}} - ) - if message_files: + shout_data = MongoDocumentsAPI.SHOUTS.get_item(item_id=msg_id) + if shout_data: attachment_data = [ attachment - for attachment in message_files["attachments"] + for attachment in shout_data["attachments"] if attachment["name"] == filename ][0] media_type = attachment_data["mime"] diff --git a/chat_server/blueprints/preferences.py b/chat_server/blueprints/preferences.py index b19de1a5..cc4105de 100644 --- a/chat_server/blueprints/preferences.py +++ b/chat_server/blueprints/preferences.py @@ -27,9 +27,8 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from fastapi import APIRouter, Request, Form -from chat_server.server_config import db_controller from chat_server.server_utils.auth import get_current_user, login_required -from chat_server.server_utils.db_utils import DbUtils +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG @@ -50,7 +49,7 @@ async def update_language( except Exception as ex: LOG.error(ex) return respond(f"Failed to update language of {cid}/{input_type} to {lang}") - DbUtils.set_user_preferences( + MongoDocumentsAPI.USERS.set_preferences( user_id=current_user_id, preferences_mapping={f"chat_language_mapping.{cid}.{input_type}": lang}, ) diff --git a/chat_server/blueprints/users.py b/chat_server/blueprints/users.py index 2208a451..0a1833d7 100644 --- a/chat_server/blueprints/users.py +++ b/chat_server/blueprints/users.py @@ -38,9 +38,9 @@ get_current_user_data, login_required, ) -from chat_server.server_utils.db_utils import DbUtils from chat_server.server_utils.http_utils import save_file from utils.common import get_hash +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG @@ -69,9 +69,7 @@ async def get_user( """ session_token = "" if user_id: - user = db_controller.exec_query( - query={"document": "users", "command": "find_one", "data": {"_id": user_id}} - ) + user = MongoDocumentsAPI.USERS.get_user(user_id=user_id) user.pop("password", None) user.pop("date_created", None) user.pop("tokens", None) @@ -109,9 +107,8 @@ async def fetch_received_user_ids( if nicknames: filter_data["nickname"] = {"$in": nicknames.split(",")} - users = db_controller.exec_query( - query={"document": "users", "command": "find", "data": filter_data}, - as_cursor=False, + users = MongoDocumentsAPI.USERS.list_items( + filters=filter_data, result_as_cursor=False ) for user in users: user.pop("password", None) @@ -209,7 +206,7 @@ async def update_settings( """ user = get_current_user(request=request) preferences_mapping = {"minify_messages": minify_messages} - DbUtils.set_user_preferences( + MongoDocumentsAPI.USERS.set_preferences( user_id=user["_id"], preferences_mapping=preferences_mapping ) return respond(msg="OK") diff --git a/chat_server/server_config.py b/chat_server/server_config.py index b852c969..57dd1750 100644 --- a/chat_server/server_config.py +++ b/chat_server/server_config.py @@ -29,7 +29,6 @@ import os from typing import Optional -from utils.logging_utils import LOG from config import Configuration from chat_server.server_utils.sftp_utils import init_sftp_connector @@ -37,17 +36,17 @@ from utils.logging_utils import LOG from utils.database_utils import DatabaseController +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI -server_config_path = os.path.expanduser(os.environ.get( - "CHATSERVER_CONFIG", "~/.local/share/neon/credentials.json" -)) -database_config_path = os.path.expanduser(os.environ.get( - "DATABASE_CONFIG", "~/.local/share/neon/credentials.json" -)) +server_config_path = os.path.expanduser( + os.environ.get("CHATSERVER_CONFIG", "~/.local/share/neon/credentials.json") +) +database_config_path = os.path.expanduser( + os.environ.get("DATABASE_CONFIG", "~/.local/share/neon/credentials.json") +) def _init_db_controller(db_config: dict) -> Optional[DatabaseController]: - from chat_server.server_utils.db_utils import DbUtils # Determine configured database dialect dialect = db_config.pop("dialect", "mongo") @@ -57,42 +56,40 @@ def _init_db_controller(db_config: dict) -> Optional[DatabaseController]: db_controller = DatabaseController(config_data=db_config) db_controller.attach_connector(dialect=dialect) db_controller.connect() + return db_controller except Exception as e: LOG.exception(f"DatabaseController init failed: {e}") return None - # Initialize convenience class - DbUtils.init(db_controller) - return db_controller - if os.path.isfile(server_config_path) or os.path.isfile(database_config_path): LOG.warning(f"Using legacy configuration at {server_config_path}") LOG.warning(f"Using legacy configuration at {database_config_path}") LOG.info(f"KLAT_ENV : {Configuration.KLAT_ENV}") - config = Configuration(from_files=[server_config_path, - database_config_path]) + config = Configuration(from_files=[server_config_path, database_config_path]) app_config = config.get("CHAT_SERVER", {}).get(Configuration.KLAT_ENV, {}) db_controller = config.get_db_controller(name="pyklatchat_3333") else: # ovos-config has built-in mechanisms for loading configuration files based # on envvars, so the configuration structure is simplified from ovos_config.config import Configuration + config = Configuration() app_config = config.get("CHAT_SERVER") or dict() env_spec = os.environ.get("KLAT_ENV") if env_spec and app_config.get(env_spec): LOG.warning("Legacy configuration handling KLAT_ENV envvar") app_config = app_config.get(env_spec) - db_controller = _init_db_controller(app_config.get("connection_properties", - config.get( - "DATABASE_CONFIG", - {}))) + db_controller = _init_db_controller( + app_config.get("connection_properties", config.get("DATABASE_CONFIG", {})) + ) LOG.info(f"App config: {app_config}") sftp_connector = init_sftp_connector(config=app_config.get("SFTP", {})) +MongoDocumentsAPI.init(db_controller=db_controller, sftp_connector=sftp_connector) + mq_api = None mq_management_config = config.get("MQ_MANAGEMENT", {}) if mq_management_url := mq_management_config.get("MQ_MANAGEMENT_URL"): diff --git a/chat_server/server_utils/auth.py b/chat_server/server_utils/auth.py index ce059d72..e3b12f59 100644 --- a/chat_server/server_utils/auth.py +++ b/chat_server/server_utils/auth.py @@ -33,23 +33,20 @@ from time import time from fastapi import Request + +from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG -from chat_server.constants.users import UserPatterns -from chat_server.server_config import db_controller, app_config -from chat_server.server_utils.db_utils import DbUtils -from utils.common import generate_uuid +from chat_server.server_config import app_config from utils.http_utils import respond cookies_config = app_config.get("COOKIES", {}) secret_key = cookies_config.get("SECRET", None) - session_lifetime = int(cookies_config.get("LIFETIME", 60 * 60)) session_refresh_rate = int(cookies_config.get("REFRESH_RATE", 5 * 60)) - jwt_encryption_algo = cookies_config.get("JWT_ALGO", "HS256") - AUTHORIZATION_HEADER = "Authorization" @@ -132,24 +129,11 @@ def create_unauthorized_user( :returns: generated UserData """ - from chat_server.server_utils.user_utils import create_from_pattern - - guest_nickname = f"guest_{generate_uuid(length=8)}" - - if nano_token: - new_user = create_from_pattern( - source=UserPatterns.GUEST_NANO, - override_defaults=dict(nickname=guest_nickname, tokens=[nano_token]), - ) - else: - new_user = create_from_pattern( - source=UserPatterns.GUEST, override_defaults=dict(nickname=guest_nickname) - ) - db_controller.exec_query( - query={"document": "users", "command": "insert_one", "data": new_user} - ) - token = generate_session_token(user_id=new_user["_id"]) if authorize else "" - LOG.debug(f"Created new user with name {new_user['nickname']}") + new_user = MongoDocumentsAPI.USERS.create_guest(nano_token=nano_token) + token = "" + if authorize: + token = generate_session_token(user_id=new_user["_id"]) + LOG.debug(f"Created new user with name {new_user['nickname']}") return UserData(user=new_user, session=token) @@ -172,12 +156,12 @@ def get_current_user_data( user_data: UserData = None if not force_tmp: if nano_token: - user = db_controller.exec_query( - query={ - "command": "find_one", - "document": "users", - "data": {"tokens": {"$all": [nano_token]}}, - } + user = MongoDocumentsAPI.USERS.get_item( + filters=MongoFilter( + key="tokens", + value=[nano_token], + logical_operator=MongoLogicalOperators.ALL, + ) ) if not user: LOG.info("Creating new user for nano agent") @@ -198,14 +182,8 @@ def get_current_user_data( int(current_timestamp) - int(payload.get("creation_time", 0)) ) <= session_lifetime: user_id = payload["sub"] - user = DbUtils.get_user(user_id=user_id) + user = MongoDocumentsAPI.USERS.get_user(user_id=user_id) LOG.info(f"Fetched user data: {user}") - user["preferences"] = DbUtils.get_user_preferences( - user_id=user_id - ) - LOG.info( - f'Fetched user preferences data: {user["preferences"]}' - ) if not user: LOG.info( f'{payload["sub"]} is not found among users, setting temporal user credentials' @@ -219,8 +197,10 @@ def get_current_user_data( LOG.info("Session was refreshed") user_data = UserData(user=user, session=session) except BaseException as ex: - LOG.exception(f"Problem resolving current user: {ex}\n" - f"setting tmp user credentials") + LOG.exception( + f"Problem resolving current user: {ex}\n" + f"setting tmp user credentials" + ) if not user_data: LOG.debug("Creating temp user") user_data = create_unauthorized_user() @@ -275,9 +255,7 @@ def validate_session( should_check_user_data = check_tmp or required_roles is_authorized = True if should_check_user_data: - from chat_server.server_utils.db_utils import DbUtils - - user = DbUtils.get_user(user_id=payload["sub"]) + user = MongoDocumentsAPI.USERS.get_user(user_id=payload["sub"]) if check_tmp and user.get("is_tmp"): is_authorized = False elif required_roles and not any( diff --git a/chat_server/server_utils/db_utils.py b/chat_server/server_utils/db_utils.py deleted file mode 100644 index bf319811..00000000 --- a/chat_server/server_utils/db_utils.py +++ /dev/null @@ -1,680 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2022 Neongecko.com Inc. -# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, -# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo -# BSD-3 License -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import re -from typing import List, Tuple, Union, Dict - -import pymongo -from bson import ObjectId -from pymongo import UpdateOne - -from chat_server.constants.conversations import ConversationSkins -from chat_server.constants.users import UserPatterns -from chat_server.server_utils.factory_utils import Singleton -from chat_server.server_utils.user_utils import create_from_pattern -from utils.common import buffer_to_base64 -from utils.database_utils.mongo_utils import * -from utils.logging_utils import LOG - - -class DbUtils(metaclass=Singleton): - """Singleton DB Utils class for convenience""" - - db_controller = None - - @classmethod - def init(cls, db_controller): - """Inits Singleton with specified database controller""" - cls.db_controller = db_controller - - @classmethod - def get_user(cls, user_id=None, nickname=None) -> Union[dict, None]: - """ - Gets user data based on provided params - :param user_id: target user id - :param nickname: target user nickname - """ - if not any( - x - for x in ( - user_id, - nickname, - ) - ): - LOG.warning("Neither user_id nor nickname was provided") - return - filter_data = {} - if user_id: - filter_data["_id"] = user_id - if nickname: - filter_data["nickname"] = nickname - return cls.db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ONE, - document=MongoDocuments.USERS, - filters=filter_data, - ) - ) - - @classmethod - def list_items( - cls, - document: MongoDocuments, - source_set: list, - key: str = "id", - value_keys: list = None, - ) -> dict: - """ - Lists items under provided document belonging to source set of provided column values - - :param document: source document to query - :param key: document's key to check - :param source_set: list of :param key values to check - :param value_keys: list of value keys to return - :returns results aggregated by :param column value - """ - if not value_keys: - value_keys = [] - if key == "id": - key = "_id" - aggregated_data = {} - if source_set: - source_set = list(set(source_set)) - items = cls.db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ALL, - document=document, - filters=MongoFilter( - key=key, - value=source_set, - logical_operator=MongoLogicalOperators.IN, - ), - ) - ) - for item in items: - items_key = item.pop(key, None) - if items_key: - aggregated_data.setdefault(items_key, []).append( - { - k: v - for k, v in item.items() - if k in value_keys or not value_keys - } - ) - return aggregated_data - - @classmethod - def get_conversation_data( - cls, - search_str: Union[list, str], - column_identifiers: List[str] = None, - limit: int = 1, - allow_regex_search: bool = False, - ) -> Union[None, dict]: - """ - Gets matching conversation data - :param search_str: search string to lookup - :param column_identifiers: desired column identifiers to look up - :param limit: limit found conversations - :param allow_regex_search: to allow search for matching entries that CONTAIN :param search_str - """ - if isinstance(search_str, str): - search_str = [search_str] - if not column_identifiers: - column_identifiers = ["_id", "conversation_name"] - or_expression = [] - for _keyword in [item for item in search_str if item is not None]: - for identifier in column_identifiers: - if identifier == "_id" and isinstance(_keyword, str): - try: - or_expression.append({identifier: ObjectId(_keyword)}) - except: - pass - if allow_regex_search: - if not _keyword: - expression = ".*" - else: - expression = f".*{_keyword}.*" - _keyword = re.compile(expression, re.IGNORECASE) - or_expression.append({identifier: _keyword}) - - conversations_data = list( - cls.db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ALL, - document=MongoDocuments.CHATS, - filters=MongoFilter( - value=or_expression, logical_operator=MongoLogicalOperators.OR - ), - result_filters={"limit": limit}, - ), - ) - ) - for conversation_data in conversations_data: - conversation_data["_id"] = str(conversation_data["_id"]) - if conversations_data and limit == 1: - conversations_data = conversations_data[0] - return conversations_data - - @classmethod - def fetch_shout_data( - cls, - conversation_data: dict, - start_idx: int = 0, - limit: int = 100, - fetch_senders: bool = True, - id_from: str = None, - shout_ids: List[str] = None, - ) -> List[dict]: - """ - Fetches shout data out of conversation data - - :param conversation_data: input conversation data - :param start_idx: message index to start from (sorted by recency) - :param limit: number of shouts to fetch - :param fetch_senders: to fetch shout senders data - :param id_from: message id to start from - :param shout_ids: list of shout ids to fetch - """ - if not shout_ids and conversation_data.get("chat_flow", None): - if id_from: - try: - start_idx = len(conversation_data["chat_flow"]) - conversation_data[ - "chat_flow" - ].index(id_from) - except ValueError: - LOG.warning("Matching start message id not found") - return [] - if start_idx == 0: - conversation_data["chat_flow"] = conversation_data["chat_flow"][ - start_idx - limit : - ] - else: - conversation_data["chat_flow"] = conversation_data["chat_flow"][ - -start_idx - limit : -start_idx - ] - shout_ids = [str(msg_id) for msg_id in conversation_data["chat_flow"]] - shouts_data = cls.fetch_shouts(shout_ids=shout_ids, fetch_senders=fetch_senders) - return sorted(shouts_data, key=lambda user_shout: int(user_shout["created_on"])) - - @classmethod - def fetch_users_from_prompt(cls, prompt: dict): - """Fetches user ids detected in provided prompt""" - prompt_data = prompt["data"] - user_ids = prompt_data.get("participating_subminds", []) - return cls.list_items( - document=MongoDocuments.USERS, - source_set=user_ids, - value_keys=["first_name", "last_name", "nickname", "is_bot", "avatar"], - ) - - @classmethod - def fetch_messages_from_prompt(cls, prompt: dict): - """Fetches message ids detected in provided prompt""" - prompt_data = prompt["data"] - message_ids = [] - for column in ( - "proposed_responses", - "submind_opinions", - "votes", - ): - message_ids.extend(list(prompt_data.get(column, {}).values())) - return cls.list_items(document=MongoDocuments.SHOUTS, source_set=message_ids) - - @classmethod - def fetch_prompt_data( - cls, - cid: str, - limit: int = 100, - id_from: str = None, - prompt_ids: List[str] = None, - fetch_user_data: bool = False, - created_from: int = None, - ) -> List[dict]: - """ - Fetches prompt data out of conversation data - - :param cid: target conversation id - :param limit: number of prompts to fetch - :param id_from: prompt id to start from - :param prompt_ids: prompt ids to fetch - :param fetch_user_data: to fetch user data in the - :param created_from: timestamp to filter messages from - - :returns list of matching prompt data along with matching messages and users - """ - filters = [MongoFilter("cid", cid)] - if id_from: - checkpoint_prompt = cls.db_controller.exec_query( - MongoQuery( - document=MongoDocuments.PROMPTS, - command=MongoCommands.FIND_ONE, - filters=MongoFilter("_id", id_from), - ) - ) - if checkpoint_prompt: - filters.append( - MongoFilter( - "created_on", - checkpoint_prompt["created_on"], - MongoLogicalOperators.LT, - ) - ) - if prompt_ids: - if isinstance(prompt_ids, str): - prompt_ids = [prompt_ids] - filters.append(MongoFilter("_id", prompt_ids, MongoLogicalOperators.IN)) - if created_from: - filters.append( - MongoFilter("created_on", created_from, MongoLogicalOperators.GT) - ) - matching_prompts = cls.db_controller.exec_query( - query=MongoQuery( - document=MongoDocuments.PROMPTS, - command=MongoCommands.FIND_ALL, - filters=filters, - result_filters={ - "sort": [("created_on", pymongo.DESCENDING)], - "limit": limit, - }, - ), - as_cursor=False, - ) - for prompt in matching_prompts: - prompt["user_mapping"] = cls.fetch_users_from_prompt(prompt) - prompt["message_mapping"] = cls.fetch_messages_from_prompt(prompt) - if fetch_user_data: - for user in prompt.get("data", {}).get("participating_subminds", []): - try: - nick = prompt["user_mapping"][user][0]["nickname"] - except KeyError: - LOG.warning( - f'user_id - "{user}" was not detected setting it as nick' - ) - nick = user - for k in ( - "proposed_responses", - "submind_opinions", - "votes", - ): - msg_id = prompt["data"][k].pop(user, "") - if msg_id: - prompt["data"][k][nick] = ( - prompt["message_mapping"] - .get(msg_id, [{}])[0] - .get("message_text") - or msg_id - ) - prompt["data"]["participating_subminds"] = [ - prompt["user_mapping"][x][0]["nickname"] - for x in prompt["data"]["participating_subminds"] - ] - return sorted(matching_prompts, key=lambda _prompt: int(_prompt["created_on"])) - - @classmethod - def fetch_skin_message_data( - cls, - skin: ConversationSkins, - conversation_data: dict, - start_idx: int = 0, - limit: int = 100, - fetch_senders: bool = True, - start_message_id: str = None, - ): - """Fetches message data based on provided conversation skin""" - message_data = cls.fetch_shout_data( - conversation_data=conversation_data, - fetch_senders=fetch_senders, - start_idx=start_idx, - id_from=start_message_id, - limit=limit, - ) - for message in message_data: - message["message_type"] = "plain" - if skin == ConversationSkins.PROMPTS: - detected_prompts = list( - set( - item.get("prompt_id") - for item in message_data - if item.get("prompt_id") - ) - ) - prompt_data = cls.fetch_prompt_data( - cid=conversation_data["_id"], prompt_ids=detected_prompts - ) - if prompt_data: - detected_prompt_ids = [] - for prompt in prompt_data: - prompt["message_type"] = "prompt" - detected_prompt_ids.append(prompt["_id"]) - message_data = [ - message - for message in message_data - if message.get("prompt_id") not in detected_prompt_ids - ] - message_data.extend(prompt_data) - return sorted(message_data, key=lambda shout: int(shout["created_on"])) - - @classmethod - def fetch_shouts( - cls, shout_ids: List[str] = None, fetch_senders: bool = True - ) -> List[dict]: - """ - Fetches shout data from provided shouts list - :param shout_ids: list of shout ids to fetch - :param fetch_senders: to fetch shout senders data - - :returns Data from requested shout ids along with matching user data - """ - if not shout_ids: - return [] - shouts = cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.FIND_ALL, - document=MongoDocuments.SHOUTS, - filters=MongoFilter( - "_id", list(set(shout_ids)), MongoLogicalOperators.IN - ), - ), - as_cursor=False, - ) - result = list() - - if fetch_senders: - user_ids = list(set([shout["user_id"] for shout in shouts])) - - users_from_shouts = cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.FIND_ALL, - document=MongoDocuments.USERS, - filters=MongoFilter("_id", user_ids, MongoLogicalOperators.IN), - ) - ) - - formatted_users = dict() - for users_from_shout in users_from_shouts: - user_id = users_from_shout.pop("_id", None) - formatted_users[user_id] = users_from_shout - - for shout in shouts: - matching_user = formatted_users.get(shout["user_id"], {}) - if not matching_user: - matching_user = create_from_pattern(UserPatterns.UNRECOGNIZED_USER) - - matching_user.pop("password", None) - matching_user.pop("is_tmp", None) - shout["message_id"] = shout["_id"] - shout_data = {**shout, **matching_user} - result.append(shout_data) - shouts = result - return shouts - - @classmethod - def get_translations(cls, translation_mapping: dict) -> Tuple[dict, dict]: - """ - Gets translation from db based on provided mapping - - :param translation_mapping: mapping of cid to desired translation language - - :return translations fetched from db - """ - populated_translations = {} - missing_translations = {} - for cid, cid_data in translation_mapping.items(): - lang = cid_data.get("lang", "en") - shout_ids = cid_data.get("shouts", []) - conversation_data = cls.get_conversation_data(search_str=cid) - if not conversation_data: - LOG.error(f"Failed to fetch conversation data - {cid}") - continue - shout_data = cls.fetch_shout_data( - conversation_data=conversation_data, - shout_ids=shout_ids, - fetch_senders=False, - ) - shout_lang = "en" - if len(shout_data) == 1: - shout_lang = shout_data[0].get("message_lang", "en") - for shout in shout_data: - message_text = shout.get("message_text") - if shout_lang != "en" and lang == "en": - shout_text = message_text - else: - shout_text = shout.get("translations", {}).get(lang) - if shout_text and lang != "en": - populated_translations.setdefault(cid, {}).setdefault("shouts", {})[ - shout["_id"] - ] = shout_text - elif message_text: - missing_translations.setdefault(cid, {}).setdefault("shouts", {})[ - shout["_id"] - ] = message_text - if missing_translations.get(cid): - missing_translations[cid]["lang"] = lang - missing_translations[cid]["source_lang"] = shout_lang - return populated_translations, missing_translations - - @classmethod - def save_translations(cls, translation_mapping: dict) -> Dict[str, List[str]]: - """ - Saves translations in DB - :param translation_mapping: mapping of cid to desired translation language - :returns dictionary containing updated shouts (those which were translated to English) - """ - updated_shouts = {} - for cid, shout_data in translation_mapping.items(): - translations = shout_data.get("shouts", {}) - bulk_update = [] - shouts = cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.FIND_ALL, - document=MongoDocuments.SHOUTS, - filters=MongoFilter( - "_id", list(translations), MongoLogicalOperators.IN - ), - ), - as_cursor=False, - ) - for shout_id, translation in translations.items(): - matching_instance = None - for shout in shouts: - if shout["_id"] == shout_id: - matching_instance = shout - break - if not matching_instance.get("translations"): - filter_expression = {"_id": shout_id} - cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.SHOUTS, - filters=filter_expression, - data={"translations": {}}, - data_action="set", - ) - ) - # English is the default language, so it is treated as message text - if shout_data.get("lang", "en") == "en": - updated_shouts.setdefault(cid, []).append(shout_id) - filter_expression = {"_id": shout_id} - update_expression = {"$set": {"message_lang": "en"}} - cls.db_controller.exec_query( - query={ - "document": "shouts", - "command": "update", - "data": ( - filter_expression, - update_expression, - ), - } - ) - bulk_update_setter = { - "message_text": translation, - "message_lang": "en", - } - else: - bulk_update_setter = { - f'translations.{shout_data["lang"]}': translation - } - # TODO: make a convenience wrapper to make bulk insertion easier to follow - bulk_update.append( - UpdateOne({"_id": shout_id}, {"$set": bulk_update_setter}) - ) - if len(bulk_update) > 0: - cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.BULK_WRITE, - document=MongoDocuments.SHOUTS, - data=bulk_update, - ) - ) - return updated_shouts - - @classmethod - def get_user_preferences(cls, user_id): - """Gets preferences of specified user, creates default if not exists""" - prefs = {"tts": {}, "chat_language_mapping": {}} - if user_id: - user = cls.get_user(user_id=user_id) or {} - if user and not user.get("preferences"): - cls.db_controller.exec_query( - MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.USERS, - filters=MongoFilter(key="_id", value=user_id), - data={"preferences": prefs}, - data_action="set", - ) - ) - else: - prefs = user.get("preferences") - else: - LOG.warning("user_id is None") - return prefs - - @classmethod - def set_user_preferences(cls, user_id, preferences_mapping: dict): - """Sets user preferences for specified user according to preferences mapping""" - if user_id: - try: - update_mapping = { - f"preferences.{key}": val - for key, val in preferences_mapping.items() - } - cls.db_controller.exec_query( - MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.USERS, - filters=MongoFilter("_id", user_id), - data=update_mapping, - data_action="set", - ) - ) - except Exception as ex: - LOG.error(f"Failed to update preferences for user_id={user_id} - {ex}") - - @classmethod - def save_tts_response( - cls, shout_id, audio_data: str, lang: str = "en", gender: str = "female" - ) -> bool: - """ - Saves TTS Response under corresponding shout id - - :param shout_id: message id to consider - :param audio_data: base64 encoded audio data received - :param lang: language of speech (defaults to English) - :param gender: language gender (defaults to female) - - :return bool if saving was successful - """ - from chat_server.server_config import sftp_connector - - audio_file_name = f"{shout_id}_{lang}_{gender}.wav" - try: - sftp_connector.put_file_object( - file_object=audio_data, save_to=f"audio/{audio_file_name}" - ) - cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.SHOUTS, - filters=MongoFilter("_id", shout_id), - data={f"audio.{lang}.{gender}": audio_file_name}, - data_action="set", - ) - ) - operation_success = True - except Exception as ex: - LOG.error(f"Failed to save TTS response to db - {ex}") - operation_success = False - return operation_success - - @classmethod - def save_stt_response(cls, shout_id, message_text: str, lang: str = "en"): - """ - Saves STT Response under corresponding shout id - - :param shout_id: message id to consider - :param message_text: STT result transcript - :param lang: language of speech (defaults to English) - """ - try: - cls.db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.SHOUTS, - filters=MongoFilter("_id", shout_id), - data={f"transcripts.{lang}": message_text}, - data_action="set", - ) - ) - except Exception as ex: - LOG.error(f"Failed to save STT response to db - {ex}") - - @classmethod - def fetch_audio_data_from_message(cls, message_id: str) -> str: - """ - Fetches audio data from message if any - :param message_id: message id to fetch - """ - shout_data = cls.fetch_shouts(shout_ids=[message_id]) - if not shout_data: - LOG.warning("Requested shout does not exist") - elif shout_data[0].get("is_audio") != "1": - LOG.warning("Failed to fetch audio data from non-audio message") - else: - from chat_server.server_config import sftp_connector - - file_location = f'audio/{shout_data[0]["message_text"]}' - LOG.info(f"Fetching existing file from: {file_location}") - fo = sftp_connector.get_file_object(file_location) - if fo.getbuffer().nbytes > 0: - return buffer_to_base64(fo) - else: - LOG.error( - f"Empty buffer received while fetching audio of message id = {message_id}" - ) - return "" diff --git a/chat_server/server_utils/enums.py b/chat_server/server_utils/enums.py index 21ec92e5..87ccddd9 100644 --- a/chat_server/server_utils/enums.py +++ b/chat_server/server_utils/enums.py @@ -26,7 +26,7 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from enum import Enum, IntEnum +from enum import Enum class DataSources(Enum): diff --git a/chat_server/server_utils/prompt_utils.py b/chat_server/server_utils/prompt_utils.py deleted file mode 100644 index 2bea8587..00000000 --- a/chat_server/server_utils/prompt_utils.py +++ /dev/null @@ -1,142 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2022 Neongecko.com Inc. -# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, -# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo -# BSD-3 License -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from enum import IntEnum - -from chat_server.server_config import db_controller -from utils.database_utils.mongo_utils import * -from utils.logging_utils import LOG - - -class PromptStates(IntEnum): - """Prompt States""" - - IDLE = 0 # No active prompt - RESP = 1 # Gathering responses to prompt - DISC = 2 # Discussing responses - VOTE = 3 # Voting on responses - PICK = 4 # Proctor will select response - WAIT = ( - 5 # Bot is waiting for the proctor to ask them to respond (not participating) - ) - - -def handle_prompt_message(message: dict) -> bool: - """ - Handles received prompt message - :param message: message dictionary received - :returns True if prompt message was handled, false otherwise - """ - try: - prompt_id = message.get("prompt_id") - prompt_state = PromptStates( - int(message.get("promptState", PromptStates.IDLE.value)) - ) - user_id = message["userID"] - message_id = message["messageID"] - ok = True - if prompt_id: - existing_prompt = ( - db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ONE, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key="_id", value=prompt_id), - ) - ) - or {} - ) - if existing_prompt and existing_prompt["is_completed"] == "0": - if user_id not in existing_prompt.get("data", {}).get( - "participating_subminds", [] - ): - data_kwargs = { - "data": {"data.participating_subminds": user_id}, - "data_action": "push", - } - db_controller.exec_query( - MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key="_id", value=prompt_id), - **data_kwargs, - ) - ) - - prompt_state_mapping = { - # PromptStates.WAIT: {'key': 'participating_subminds', 'type': list}, - PromptStates.RESP: { - "key": f"proposed_responses.{user_id}", - "type": dict, - "data": message_id, - }, - PromptStates.DISC: { - "key": f"submind_opinions.{user_id}", - "type": dict, - "data": message_id, - }, - PromptStates.VOTE: { - "key": f"votes.{user_id}", - "type": dict, - "data": message_id, - }, - } - store_key_properties = prompt_state_mapping.get(prompt_state) - if not store_key_properties: - LOG.warning( - f"Prompt State - {prompt_state.name} has no db store properties" - ) - else: - store_key = store_key_properties["key"] - store_type = store_key_properties["type"] - store_data = store_key_properties["data"] - if user_id in list( - existing_prompt.get("data", {}).get(store_key, {}) - ): - LOG.error( - f"user_id={user_id} tried to duplicate data to prompt_id={prompt_id}, store_key={store_key}" - ) - else: - data_kwargs = { - "data": {f"data.{store_key}": store_data}, - "data_action": "push" if store_type == list else "set", - } - db_controller.exec_query( - MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key="_id", value=prompt_id), - **data_kwargs, - ) - ) - else: - ok = False - except Exception as ex: - LOG.error(f"Failed to handle prompt message - {message} ({ex})") - ok = False - return ok diff --git a/chat_server/server_utils/user_utils.py b/chat_server/server_utils/user_utils.py index b3fef4c3..78361418 100644 --- a/chat_server/server_utils/user_utils.py +++ b/chat_server/server_utils/user_utils.py @@ -58,87 +58,3 @@ def create_from_pattern(source: UserPatterns, override_defaults: dict = None) -> matching_data.setdefault("is_tmp", True) return matching_data - - -def get_neon_data(db_controller: DatabaseController, skill_name: str = "neon") -> dict: - """ - Gets a user profile for the user 'Neon' and adds it to the users db if not already present - - :param db_controller: db controller instance - :param skill_name: Neon Skill to consider (defaults to neon - Neon Assistant) - - :return Neon AI data - """ - neon_data = db_controller.exec_query( - {"command": "find_one", "document": "users", "data": {"nickname": skill_name}} - ) - if not neon_data: - last_name = "AI" if skill_name == "neon" else skill_name.capitalize() - nickname = skill_name - neon_data = create_from_pattern( - source=UserPatterns.NEON, - override_defaults={"last_name": last_name, "nickname": nickname}, - ) - db_controller.exec_query( - MongoQuery( - command=MongoCommands.INSERT_ONE, - document=MongoDocuments.USERS, - data=neon_data, - ) - ) - return neon_data - - -def get_bot_data( - db_controller: DatabaseController, nickname: str, context: dict = None -) -> dict: - """ - Gets a user profile for the requested bot instance and adds it to the users db if not already present - - :param db_controller: db controller instance - :param nickname: nickname of the bot provided - :param context: context with additional bot information (optional) - - :return Matching bot data - """ - if not context: - context = {} - full_nickname = nickname - nickname = nickname.split("-")[0] - bot_data = db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ONE, - document=MongoDocuments.USERS, - filters=MongoFilter(key="nickname", value=nickname), - ) - ) - if not bot_data: - bot_data = dict( - _id=generate_uuid(length=20), - first_name=context.get("first_name", nickname.capitalize()), - last_name=context.get("last_name", ""), - avatar=context.get("avatar", ""), - password=get_hash(generate_uuid()), - nickname=nickname, - is_bot="1", - full_nickname=full_nickname, - date_created=int(time()), - is_tmp=False, - ) - db_controller.exec_query( - MongoQuery( - command=MongoCommands.INSERT_ONE, - document=MongoDocuments.USERS, - data=bot_data, - ) - ) - elif not bot_data.get("is_bot") == "1": - db_controller.exec_query( - MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.USERS, - filters=MongoFilter("_id", bot_data["_id"]), - data={"is_bot": "1"}, - ) - ) - return bot_data diff --git a/chat_server/services/popularity_counter.py b/chat_server/services/popularity_counter.py index 925439f3..7d8b8585 100644 --- a/chat_server/services/popularity_counter.py +++ b/chat_server/services/popularity_counter.py @@ -31,12 +31,10 @@ from utils.database_utils.mongo_utils import ( - MongoQuery, - MongoCommands, - MongoDocuments, MongoFilter, MongoLogicalOperators, ) +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG @@ -79,27 +77,16 @@ def init_data(cls, actuality_days: int = 7): :param actuality_days: number of days for message to affect the chat popularity """ - from chat_server.server_utils.db_utils import DbUtils - curr_time = int(time()) - chats = DbUtils.db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ALL, - document=MongoDocuments.CHATS, - filters=MongoFilter(key="is_private", value=False), - ), - as_cursor=False, - ) - relevant_shouts = DbUtils.db_controller.exec_query( - MongoQuery( - command=MongoCommands.FIND_ALL, - document=MongoDocuments.SHOUTS, - filters=MongoFilter( + chats = MongoDocumentsAPI.CHATS.list_items(include_private=False) + relevant_shouts = MongoDocumentsAPI.SHOUTS.list_items( + filters=[ + MongoFilter( key="created_on", logical_operator=MongoLogicalOperators.GTE, value=curr_time - 3600 * 24 * actuality_days, - ), - ) + ) + ] ) relevant_shouts = set(x["_id"] for x in relevant_shouts) formatted_chats = [] diff --git a/chat_server/sio.py b/chat_server/sio.py index 42d55504..efc7f6c1 100644 --- a/chat_server/sio.py +++ b/chat_server/sio.py @@ -26,7 +26,6 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import json import os import socketio @@ -35,22 +34,16 @@ from typing import List, Optional from cachetools import LRUCache + +from utils.database_utils.mongo_utils.queries import mongo_queries +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG from utils.common import generate_uuid, deep_merge, buffer_to_base64 from chat_server.server_utils.auth import validate_session from chat_server.server_utils.cache_utils import CacheFactory -from chat_server.server_utils.db_utils import ( - DbUtils, - MongoCommands, - MongoDocuments, - MongoQuery, - MongoFilter, -) -from chat_server.server_utils.prompt_utils import handle_prompt_message -from chat_server.server_utils.user_utils import get_neon_data, get_bot_data from chat_server.server_utils.languages import LanguageSettings -from chat_server.server_config import db_controller, sftp_connector +from chat_server.server_config import sftp_connector from chat_server.services.popularity_counter import PopularityCounter sio = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi") @@ -154,7 +147,6 @@ async def user_message(sid, data): ``` data = {'cid':'conversation id', 'userID': 'emitted user id', - 'messageID': 'id of emitted message', 'promptID': 'id of related prompt (optional)', 'source': 'declared name of the source that shouted given user message' 'messageText': 'content of the user message', @@ -172,9 +164,9 @@ async def user_message(sid, data): """ LOG.debug(f"Got new user message from {sid}: {data}") try: - filter_expression = dict(_id=data["cid"]) - cid_data = DbUtils.get_conversation_data( - data["cid"], column_identifiers=["_id"] + cid_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=data["cid"], + column_identifiers=["_id"], ) if not cid_data: msg = "Shouting to non-existent conversation, skipping further processing" @@ -182,26 +174,14 @@ async def user_message(sid, data): return LOG.info(f"Received user message data: {data}") - data["messageID"] = data.get("messageID") - if data["messageID"]: - existing_shout = DbUtils.fetch_shouts( - shout_ids=[data["messageID"]], fetch_senders=False - ) - if existing_shout: - raise ValueError( - f'messageID value="{data["messageID"]}" already exists' - ) - else: - data["messageID"] = generate_uuid() + data["message_id"] = generate_uuid() data["is_bot"] = data.pop("bot", "0") if data["userID"].startswith("neon"): - neon_data = get_neon_data(db_controller=db_controller) + neon_data = MongoDocumentsAPI.USERS.get_neon_data(skill_name="neon") data["userID"] = neon_data["_id"] elif data["is_bot"] == "1": - bot_data = get_bot_data( - db_controller=db_controller, - nickname=data["userID"], - context=data.get("context", None), + bot_data = MongoDocumentsAPI.USERS.get_bot_data( + nickname=data["userID"], context=data.get("context") ) data["userID"] = bot_data["_id"] @@ -210,7 +190,7 @@ async def user_message(sid, data): if is_audio != "1": is_audio = "0" - audio_path = f'{data["messageID"]}_audio.wav' + audio_path = f'{data["message_id"]}_audio.wav' try: if is_audio == "1": message_text = data["messageText"].split(",")[-1] @@ -232,7 +212,7 @@ async def user_message(sid, data): data["prompt_id"] = data.pop("promptID", "") new_shout_data = { - "_id": data["messageID"], + "_id": data["message_id"], "cid": data["cid"], "user_id": data["userID"], "prompt_id": data["prompt_id"], @@ -252,24 +232,9 @@ async def user_message(sid, data): if lang != "en": new_shout_data["translations"][lang] = data["messageText"] - db_controller.exec_query( - MongoQuery( - command=MongoCommands.INSERT_ONE, - document=MongoDocuments.SHOUTS, - data=new_shout_data, - ) - ) - db_controller.exec_query( - query=MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.CHATS, - filters=filter_expression, - data={"chat_flow": new_shout_data["_id"]}, - data_action="push", - ) - ) + mongo_queries.add_shout(data=new_shout_data) if is_announcement == "0" and data["prompt_id"]: - is_ok = handle_prompt_message(data) + is_ok = MongoDocumentsAPI.PROMPTS.add_shout_to_prompt(data) if is_ok: await sio.emit( "new_prompt_message", @@ -285,12 +250,9 @@ async def user_message(sid, data): message_tts = data.get("messageTTS", {}) for language, gender_mapping in message_tts.items(): for gender, audio_data in gender_mapping.items(): - sftp_connector.put_file_object( - file_object=audio_data, save_to=f"audio/{audio_path}" - ) - DbUtils.save_tts_response( - shout_id=data["messageID"], - audio_file_name=audio_path, + MongoDocumentsAPI.SHOUTS.save_tts_response( + shout_id=data["message_id"], + audio_data=audio_data, lang=language, gender=gender, ) @@ -334,13 +296,7 @@ async def new_prompt(sid, data): "data": {"prompt_text": prompt_text}, "created_on": created_on, } - db_controller.exec_query( - MongoQuery( - command=MongoCommands.INSERT_ONE, - document=MongoDocuments.PROMPTS, - data=formatted_data, - ) - ) + MongoDocumentsAPI.PROMPTS.add_item(data=formatted_data) await sio.emit("new_prompt_created", data=formatted_data) except Exception as ex: LOG.error(f'Prompt "{prompt_id}" was not created due to exception - {ex}') @@ -355,28 +311,15 @@ async def prompt_completed(sid, data): :param data: user message data """ prompt_id = data["context"]["prompt"]["prompt_id"] - prompt_summary_keys = ["winner", "votes_per_submind"] - prompt_summary_agg = { - f"data.{k}": v for k, v in data["context"].items() if k in prompt_summary_keys + + MongoDocumentsAPI.PROMPTS.set_completed( + prompt_id=prompt_id, prompt_context=data["context"] + ) + formatted_data = { + "winner": data["context"].get("winner", ""), + "prompt_id": prompt_id, } - prompt_summary_agg["is_completed"] = "1" - try: - db_controller.exec_query( - MongoQuery( - command=MongoCommands.UPDATE_MANY, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key="_id", value=prompt_id), - data=prompt_summary_agg, - data_action="set", - ) - ) - formatted_data = { - "winner": data["context"].get("winner", ""), - "prompt_id": prompt_id, - } - await sio.emit("set_prompt_completed", data=formatted_data) - except Exception as ex: - LOG.error(f'Prompt "{prompt_id}" was not updated due to exception - {ex}') + await sio.emit("set_prompt_completed", data=formatted_data) @sio.event @@ -394,7 +337,7 @@ async def get_prompt_data(sid, data): ``` """ prompt_id = data.get("prompt_id") - _prompt_data = DbUtils.fetch_prompt_data( + _prompt_data = mongo_queries.fetch_prompt_data( cid=data["cid"], limit=data.get("limit", 5), prompt_ids=[prompt_id], @@ -439,7 +382,7 @@ async def request_translate(sid, data): else: input_type = data.get("inputType", "incoming") - populated_translations, missing_translations = DbUtils.get_translations( + populated_translations, missing_translations = mongo_queries.get_translations( translation_mapping=data.get("chat_mapping", {}) ) if populated_translations and not missing_translations: @@ -492,7 +435,9 @@ async def get_neon_translations(sid, data): return sid = cached_data.get("sid") input_type = cached_data.get("input_type") - updated_shouts = DbUtils.save_translations(data.get("translations", {})) + updated_shouts = MongoDocumentsAPI.SHOUTS.save_translations( + translation_mapping=data.get("translations", {}) + ) populated_translations = deep_merge( data.get("translations", {}), cached_data.get("translations", {}) ) @@ -542,25 +487,23 @@ async def request_tts(sid, data): message_id = data["message_id"] user_id = data["user_id"] cid = data["cid"] - matching_messages = DbUtils.fetch_shouts( - shout_ids=[message_id], fetch_senders=False - ) - if not matching_messages: + matching_message = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id) + if not matching_message: LOG.error("Failed to request TTS - matching message not found") else: - matching_message = matching_messages[0] - + # TODO: support for multiple genders in TTS # Trying to get existing audio data - preferred_gender = ( - DbUtils.get_user_preferences(user_id=user_id) - .get("tts", {}) - .get(lang, {}) - .get("gender", "female") - ) - existing_audio_file = ( + # preferred_gender = ( + # MongoDocumentsAPI.USERS.get_preferences(user_id=user_id) + # .get("tts", {}) + # .get(lang, {}) + # .get("gender", "female") + # ) + preferred_gender = "female" + audio_file = ( matching_message.get("audio", {}).get(lang, {}).get(preferred_gender) ) - if not existing_audio_file: + if not audio_file: LOG.info( f"File was not detected for cid={cid}, message_id={message_id}, lang={lang}" ) @@ -575,7 +518,7 @@ async def request_tts(sid, data): await sio.emit("get_tts", data=formatted_data) else: try: - file_location = f"audio/{existing_audio_file}" + file_location = f"audio/{audio_file}" LOG.info(f"Fetching existing file from: {file_location}") fo = sftp_connector.get_file_object(file_location) if fo.getbuffer().nbytes > 0: @@ -608,8 +551,8 @@ async def tts_response(sid, data): sid = mq_context.get("sid") lang = LanguageSettings.to_system_lang(data.get("lang", "en-us")) lang_gender = data.get("gender", "undefined") - matching_shouts = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if not matching_shouts: + matching_shout = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id) + if not matching_shout: LOG.warning( f"Skipping TTS Response for message_id={message_id} - matching shout does not exist" ) @@ -620,7 +563,7 @@ async def tts_response(sid, data): f"Skipping TTS Response for message_id={message_id} - audio data is empty" ) else: - is_ok = DbUtils.save_tts_response( + is_ok = MongoDocumentsAPI.SHOUTS.save_tts_response( shout_id=message_id, audio_data=audio_data, lang=lang, @@ -651,8 +594,8 @@ async def stt_response(sid, data): """Handle STT Response from Observer""" mq_context = data.get("context", {}) message_id = mq_context.get("message_id") - matching_shouts = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if not matching_shouts: + matching_shout = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id) + if not matching_shout: LOG.warning( f"Skipping STT Response for message_id={message_id} - matching shout does not exist" ) @@ -660,7 +603,7 @@ async def stt_response(sid, data): try: message_text = data.get("transcript") lang = LanguageSettings.to_system_lang(data["lang"]) - DbUtils.save_stt_response( + MongoDocumentsAPI.SHOUTS.save_stt_response( shout_id=message_id, message_text=message_text, lang=lang ) sid = mq_context.get("sid") @@ -703,20 +646,23 @@ async def request_stt(sid, data): # TODO: process received language lang = "en" # lang = data.get('lang', 'en') - existing_shouts = DbUtils.fetch_shouts(shout_ids=[message_id]) - if existing_shouts: - existing_transcript = existing_shouts[0].get("transcripts", {}).get(lang) - if existing_transcript: + if shout_data := MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id): + message_transcript = shout_data.get("transcripts", {}).get(lang) + if message_transcript: response_data = { "cid": cid, "message_id": message_id, "lang": lang, - "message_text": existing_transcript, + "message_text": message_transcript, } return await sio.emit("incoming_stt", data=response_data, to=sid) - audio_data = data.get("audio_data") or DbUtils.fetch_audio_data_from_message( - message_id - ) + else: + err_msg = "Message transcript was missing" + LOG.error(err_msg) + return await emit_error(message=err_msg, sids=[sid]) + audio_data = data.get( + "audio_data" + ) or MongoDocumentsAPI.SHOUTS.fetch_audio_data(message_id=message_id) if not audio_data: LOG.error("Failed to fetch audio data") else: diff --git a/chat_server/tests/test_sio.py b/chat_server/tests/test_sio.py index c065766b..d0a0ca25 100644 --- a/chat_server/tests/test_sio.py +++ b/chat_server/tests/test_sio.py @@ -37,9 +37,9 @@ from chat_server.constants.users import ChatPatterns from chat_server.tests.beans.server import ASGITestServer -from chat_server.server_utils.auth import generate_uuid from chat_server.server_config import db_controller from utils.logging_utils import LOG +from utils.common import generate_uuid SERVER_ADDRESS = "http://127.0.0.1:8888" TEST_CID = "-1" @@ -69,6 +69,7 @@ class TestSIO(unittest.TestCase): @classmethod def setUpClass(cls) -> None: from chat_server.server_config import database_config_path + assert os.path.isfile(database_config_path) os.environ["DISABLE_AUTH_CHECK"] = "1" matching_conversation = db_controller.exec_query( diff --git a/config.py b/config.py index b4bd6b7d..4ce20210 100644 --- a/config.py +++ b/config.py @@ -143,8 +143,6 @@ def get_db_controller( :returns instance of Database Controller """ - from chat_server.server_utils.db_utils import DbUtils - db_controller = self.db_controllers.get(name, None) if not db_controller or override: db_config = self.get_db_config_from_key(key=name) @@ -160,5 +158,4 @@ def get_db_controller( db_controller = DatabaseController(config_data=db_config) db_controller.attach_connector(dialect=dialect) db_controller.connect() - DbUtils.init(db_controller) return db_controller diff --git a/migration_scripts/shouts.py b/migration_scripts/shouts.py index 245b7717..275a9f52 100644 --- a/migration_scripts/shouts.py +++ b/migration_scripts/shouts.py @@ -30,13 +30,13 @@ from pymongo import ReplaceOne, UpdateOne from chat_server.server_utils.db_utils import ( - DbUtils, MongoQuery, MongoCommands, MongoDocuments, ) from migration_scripts.utils.shout_utils import prepare_nicks_for_sql from migration_scripts.utils.sql_utils import iterable_to_sql_array, sql_arr_is_null +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG @@ -137,7 +137,9 @@ def remap_creation_timestamp(db_controller): filter_stage = {"$match": {"created_on": {"$gte": 10**12}}} bulk_update = [] res = list( - DbUtils.db_controller.connector.connection["shouts"].aggregate([filter_stage]) + MongoDocumentsAPI.db_controller.connector.connection["shouts"].aggregate( + [filter_stage] + ) ) for item in res: bulk_update.append( @@ -171,7 +173,7 @@ def set_cid_to_shouts(db_controller): bulk_update.append( UpdateOne({"_id": shout}, {"$set": {"cid": item["_id"]}}) ) - DbUtils.db_controller.exec_query( + MongoDocumentsAPI.db_controller.exec_query( query=MongoQuery( command=MongoCommands.BULK_WRITE, document=MongoDocuments.SHOUTS, diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..718d1b00 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,27 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/utils/database_utils/mongo_utils/queries/__init__.py b/utils/database_utils/mongo_utils/queries/__init__.py new file mode 100644 index 00000000..718d1b00 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/__init__.py @@ -0,0 +1,27 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/utils/database_utils/mongo_utils/queries/constants.py b/utils/database_utils/mongo_utils/queries/constants.py new file mode 100644 index 00000000..0a41108c --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/constants.py @@ -0,0 +1,26 @@ +from enum import Enum + + +class UserPatterns(Enum): + """Collection of user patterns used for commonly in conversations""" + + UNRECOGNIZED_USER = { + "first_name": "Deleted", + "last_name": "User", + "nickname": "deleted_user", + } + GUEST = {"first_name": "Klat", "last_name": "Guest"} + NEON = { + "first_name": "Neon", + "last_name": "AI", + "nickname": "neon", + "avatar": "neon.webp", + } + GUEST_NANO = {"first_name": "Nano", "last_name": "Guest", "tokens": []} + + +class ConversationSkins: + """List of supported conversation skins""" + + BASE = "base" + PROMPTS = "prompts" diff --git a/utils/database_utils/mongo_utils/queries/dao/__init__.py b/utils/database_utils/mongo_utils/queries/dao/__init__.py new file mode 100644 index 00000000..718d1b00 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/__init__.py @@ -0,0 +1,27 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/utils/database_utils/mongo_utils/queries/dao/abc.py b/utils/database_utils/mongo_utils/queries/dao/abc.py new file mode 100644 index 00000000..c432ec6e --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/abc.py @@ -0,0 +1,176 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from abc import ABC, abstractmethod + +from neon_sftp import NeonSFTPConnector + +from utils.database_utils import DatabaseController +from utils.database_utils.mongo_utils import ( + MongoQuery, + MongoCommands, + MongoFilter, + MongoLogicalOperators, +) + + +class MongoDocumentDAO(ABC): + def __init__( + self, + db_controller: DatabaseController, + sftp_connector: NeonSFTPConnector = None, + ): + self.db_controller = db_controller + self.sftp_connector = sftp_connector + + @property + @abstractmethod + def document(self): + pass + + def list_contains( + self, + key: str = "_id", + source_set: list = None, + aggregate_result: bool = True, + *args, + **kwargs + ) -> dict: + items = {} + contains_filter = self._build_contains_filter(key=key, lookup_set=source_set) + if contains_filter: + filters = kwargs.pop("filters", []) + [contains_filter] + items = self.list_items(filters=filters, *args, **kwargs) + if aggregate_result: + items = self.aggregate_items_by_key(key=key, items=items) + return items + + def list_items( + self, + filters: list[MongoFilter] = None, + projection_attributes: list = None, + limit: int = None, + result_as_cursor: bool = True, + ) -> dict: + """ + Lists items under provided document belonging to source set of provided column values + + :param filters: filters to consider (optional) + :param projection_attributes: list of value keys to return (optional) + :param limit: limit number of returned attributes (optional) + :param result_as_cursor: to return result as cursor (defaults to True) + :returns results of FIND operation over the desired document according to applied filters + """ + result_filters = {} + if limit: + result_filters["limit"] = limit + items = self._execute_query( + command=MongoCommands.FIND_ALL, + filters=filters, + result_filters=result_filters, + result_as_cursor=result_as_cursor, + ) + # TODO: pymongo support projection only as aggregation API which is not yet implemented in project + if projection_attributes: + items = [ + {k: v} + for item in items + for k, v in item.items() + if k in projection_attributes + ] + return items + + def aggregate_items_by_key(self, key, items: list) -> dict: + aggregated_data = {} + # TODO: consider Mongo DB aggregation API + for item in items: + items_key = item.pop(key, None) + if items_key: + aggregated_data.setdefault(items_key, []).append(item) + return aggregated_data + + def _build_list_items_filter( + self, key, lookup_set, additional_filters: list[MongoFilter] + ) -> list[MongoFilter] | None: + mongo_filters = additional_filters or [] + contains_filter = self._build_contains_filter(key=key, lookup_set=lookup_set) + if contains_filter: + mongo_filters.append(contains_filter) + return mongo_filters + + def _build_contains_filter(self, key, lookup_set) -> MongoFilter | None: + mongo_filter = None + if key and lookup_set: + lookup_set = list(set(lookup_set)) + mongo_filter = MongoFilter( + key=key, + value=lookup_set, + logical_operator=MongoLogicalOperators.IN, + ) + return mongo_filter + + def add_item(self, data: dict): + return self._execute_query(command=MongoCommands.INSERT_ONE, data=data) + + def get_item( + self, item_id: str = None, filters: list[dict | MongoFilter] = None + ) -> dict | None: + if not filters: + filters = [] + if item_id: + if not isinstance(filters, list): + filters = [filters] + filters.append(MongoFilter(key="_id", value=item_id)) + if not filters: + return + return self._execute_query(command=MongoCommands.FIND_ONE, filters=filters) + + def _execute_query( + self, + command: MongoCommands, + filters: list[MongoFilter] = None, + data: dict = None, + data_action: str = "set", + result_filters: dict = None, + result_as_cursor: bool = True, + *args, + **kwargs + ): + return self.db_controller.exec_query( + MongoQuery( + command=command, + document=self.document, + filters=filters, + data=data, + data_action=data_action, + result_filters=result_filters, + ), + as_cursor=result_as_cursor, + *args, + **kwargs + ) diff --git a/utils/database_utils/mongo_utils/queries/dao/chats.py b/utils/database_utils/mongo_utils/queries/dao/chats.py new file mode 100644 index 00000000..53f02295 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/chats.py @@ -0,0 +1,128 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re +from typing import Union, List + +from bson import ObjectId + +from utils.database_utils.mongo_utils import ( + MongoDocuments, + MongoCommands, + MongoFilter, + MongoLogicalOperators, +) +from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO +from utils.logging_utils import LOG + + +class ChatsDAO(MongoDocumentDAO): + @property + def document(self): + return MongoDocuments.CHATS + + def get_conversation_data( + self, + search_str: Union[list, str], + column_identifiers: List[str] = None, + limit: int = 1, + allow_regex_search: bool = False, + projection_attributes: dict = None, + include_private: bool = False, + ) -> Union[None, dict]: + """ + Gets matching conversation data + :param search_str: search string to lookup + :param column_identifiers: desired column identifiers to look up + :param limit: limit found conversations + :param allow_regex_search: to allow search for matching entries that CONTAIN :param search_str + :param projection_attributes: mapping of attributes to project (optional) + :param include_private: to include private conversations (defaults to False) + """ + if isinstance(search_str, str): + search_str = [search_str] + if not column_identifiers: + column_identifiers = ["_id", "conversation_name"] + or_expression = [] + for _keyword in [item for item in search_str if item is not None]: + for identifier in column_identifiers: + if identifier == "_id" and isinstance(_keyword, str): + try: + or_expression.append({identifier: ObjectId(_keyword)}) + except Exception as ex: + LOG.debug(f"Failed to add {_keyword = }| {ex = }") + if allow_regex_search: + if not _keyword: + expression = ".*" + else: + expression = f".*{_keyword}.*" + _keyword = re.compile(expression, re.IGNORECASE) + or_expression.append({identifier: _keyword}) + + chats = self.list_items( + filters=[ + MongoFilter( + value=or_expression, logical_operator=MongoLogicalOperators.OR + ) + ], + projection_attributes=projection_attributes, + limit=limit, + result_as_cursor=False, + include_private=include_private, + ) + for chat in chats: + chat["_id"] = str(chat["_id"]) + if chats and limit == 1: + chats = chats[0] + return chats + + def add_shout(self, cid: str, shout_id: str): + return self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter(key="_id", value=cid), + data={"chat_flow": shout_id}, + data_action="push", + ) + + def list_items( + self, + filters: list[MongoFilter] = None, + projection_attributes: list = None, + limit: int = None, + result_as_cursor: bool = True, + include_private: bool = False, + ) -> dict: + filters = filters or [] + if not include_private: + filters.append(MongoFilter(key="is_private", value=False)) + return super().list_items( + filters=filters, + projection_attributes=projection_attributes, + limit=limit, + result_as_cursor=result_as_cursor, + ) diff --git a/utils/database_utils/mongo_utils/queries/dao/prompts.py b/utils/database_utils/mongo_utils/queries/dao/prompts.py new file mode 100644 index 00000000..c09731f3 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/prompts.py @@ -0,0 +1,191 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from enum import IntEnum +from typing import List + +import pymongo + +from utils.database_utils.mongo_utils import ( + MongoDocuments, + MongoCommands, + MongoFilter, + MongoLogicalOperators, +) +from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO +from utils.logging_utils import LOG + + +class PromptStates(IntEnum): + """Prompt States""" + + IDLE = 0 # No active prompt + RESP = 1 # Gathering responses to prompt + DISC = 2 # Discussing responses + VOTE = 3 # Voting on responses + PICK = 4 # Proctor will select response + WAIT = ( + 5 # Bot is waiting for the proctor to ask them to respond (not participating) + ) + + +class PromptsDAO(MongoDocumentDAO): + @property + def document(self): + return MongoDocuments.PROMPTS + + def set_completed(self, prompt_id: str, prompt_context: dict): + prompt_summary_keys = ["winner", "votes_per_submind"] + prompt_summary_agg = { + f"data.{k}": v + for k, v in prompt_context.items() + if k in prompt_summary_keys + } + prompt_summary_agg["is_completed"] = "1" + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter(key="_id", value=prompt_id), + data=prompt_summary_agg, + ) + + def get_prompts( + self, + cid: str, + limit: int = 100, + id_from: str = None, + prompt_ids: List[str] = None, + created_from: int = None, + ) -> List[dict]: + """ + Fetches prompt data out of conversation data + + :param cid: target conversation id + :param limit: number of prompts to fetch + :param id_from: prompt id to start from + :param prompt_ids: prompt ids to fetch + :param fetch_user_data: to fetch user data in the + :param created_from: timestamp to filter messages from + + :returns list of matching prompt data along with matching messages and users + """ + filters = [MongoFilter("cid", cid)] + if id_from: + checkpoint_prompt = self._execute_query( + command=MongoCommands.FIND_ONE, + filters=MongoFilter("_id", id_from), + ) + if checkpoint_prompt: + filters.append( + MongoFilter( + "created_on", + checkpoint_prompt["created_on"], + MongoLogicalOperators.LT, + ) + ) + if prompt_ids: + if isinstance(prompt_ids, str): + prompt_ids = [prompt_ids] + filters.append(MongoFilter("_id", prompt_ids, MongoLogicalOperators.IN)) + if created_from: + filters.append( + MongoFilter("created_on", created_from, MongoLogicalOperators.GT) + ) + matching_prompts = self._execute_query( + command=MongoCommands.FIND_ALL, + filters=filters, + result_filters={ + "sort": [("created_on", pymongo.DESCENDING)], + "limit": limit, + }, + result_as_cursor=False, + ) + return matching_prompts + + def add_shout_to_prompt( + self, prompt_id: str, user_id: str, message_id: str, prompt_state: PromptStates + ): + prompt = self.get_item(item_id=prompt_id) + if prompt and prompt["is_completed"] == "0": + if ( + user_id not in prompt.get("data", {}).get("participating_subminds", []) + and prompt_state == PromptStates.RESP + ): + self._add_participant(prompt_id=prompt_id, user_id=user_id) + prompt_state_structure = self._get_prompt_state_structure( + prompt_state=prompt_state, user_id=user_id, message_id=message_id + ) + if not prompt_state_structure: + LOG.warning( + f"Prompt State - {prompt_state.name} has no db store properties" + ) + else: + store_key = prompt_state_structure["key"] + store_type = prompt_state_structure["type"] + store_data = prompt_state_structure["data"] + if user_id in list(prompt.get("data", {}).get(store_key, {})): + LOG.error( + f"user_id={user_id} tried to duplicate data to prompt_id={prompt_id}, store_key={store_key}" + ) + else: + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter(key="_id", value=prompt_id), + data={f"data.{store_key}": store_data}, + data_action="push" if store_type == list else "set", + ) + + def _add_participant(self, prompt_id: str, user_id: str): + return self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter(key="_id", value=prompt_id), + data={"data.participating_subminds": user_id}, + data_action="push", + ) + + @staticmethod + def _get_prompt_state_structure( + prompt_state: PromptStates, user_id: str, message_id: str + ): + prompt_state_mapping = { + # PromptStates.WAIT: {'key': 'participating_subminds', 'type': list}, + PromptStates.RESP: { + "key": f"proposed_responses.{user_id}", + "type": dict, + "data": message_id, + }, + PromptStates.DISC: { + "key": f"submind_opinions.{user_id}", + "type": dict, + "data": message_id, + }, + PromptStates.VOTE: { + "key": f"votes.{user_id}", + "type": dict, + "data": message_id, + }, + } + return prompt_state_mapping.get(prompt_state) diff --git a/utils/database_utils/mongo_utils/queries/dao/shouts.py b/utils/database_utils/mongo_utils/queries/dao/shouts.py new file mode 100644 index 00000000..490c674a --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/shouts.py @@ -0,0 +1,197 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from typing import List, Dict + +from ovos_utils import LOG +from pymongo import UpdateOne + +from utils.common import buffer_to_base64 +from utils.database_utils.mongo_utils import ( + MongoDocuments, + MongoCommands, + MongoFilter, + MongoLogicalOperators, + MongoQuery, +) +from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO + + +class ShoutsDAO(MongoDocumentDAO): + @property + def document(self): + return MongoDocuments.SHOUTS + + def fetch_shouts(self, shout_ids: List[str] = None) -> List[dict]: + """ + Fetches shout data from provided shouts list + :param shout_ids: list of shout ids to fetch + + :returns Data from requested shout ids along with matching user data + """ + return self.list_contains( + source_set=shout_ids, aggregate_result=False, result_as_cursor=False + ) + + def fetch_messages_from_prompt(self, prompt: dict): + """Fetches message ids detected in provided prompt""" + prompt_data = prompt["data"] + message_ids = [] + for column in ( + "proposed_responses", + "submind_opinions", + "votes", + ): + message_ids.extend(list(prompt_data.get(column, {}).values())) + return self.list_contains(source_set=message_ids) + + def fetch_audio_data(self, message_id: str): + """ + Fetches audio data from message + :param message_id: message id to fetch + """ + shout_data = self.get_item(item_id=message_id) + if not shout_data: + LOG.warning("Requested shout does not exist") + elif shout_data.get("is_audio") != "1": + LOG.warning("Failed to fetch audio data from non-audio message") + else: + + file_location = f'audio/{shout_data["message_text"]}' + LOG.info(f"Fetching existing file from: {file_location}") + fo = self.sftp_connector.get_file_object(file_location) + if fo.getbuffer().nbytes > 0: + return buffer_to_base64(fo) + else: + LOG.error( + f"Empty buffer received while fetching audio of message id = {message_id}" + ) + return "" + + def save_translations(self, translation_mapping: dict) -> Dict[str, List[str]]: + """ + Saves translations in DB + :param translation_mapping: mapping of cid to desired translation language + :returns dictionary containing updated shouts (those which were translated to English) + """ + updated_shouts = {} + for cid, shout_data in translation_mapping.items(): + translations = shout_data.get("shouts", {}) + bulk_update = [] + shouts = self._execute_query( + command=MongoCommands.FIND_ALL, + filters=MongoFilter( + "_id", list(translations), MongoLogicalOperators.IN + ), + result_as_cursor=False, + ) + for shout_id, translation in translations.items(): + matching_instance = None + for shout in shouts: + if shout["_id"] == shout_id: + matching_instance = shout + break + filter_expression = MongoFilter("_id", shout_id) + if not matching_instance.get("translations"): + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=filter_expression, + data={"translations": {}}, + ) + # English is the default language, so it is treated as message text + if shout_data.get("lang", "en") == "en": + updated_shouts.setdefault(cid, []).append(shout_id) + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=filter_expression, + data={"message_lang": "en"}, + ) + bulk_update_setter = { + "message_text": translation, + "message_lang": "en", + } + else: + bulk_update_setter = { + f'translations.{shout_data["lang"]}': translation + } + # TODO: make a convenience wrapper to make bulk insertion easier to follow + bulk_update.append( + UpdateOne({"_id": shout_id}, {"$set": bulk_update_setter}) + ) + if bulk_update: + self._execute_query( + command=MongoCommands.BULK_WRITE, + data=bulk_update, + ) + return updated_shouts + + def save_tts_response( + self, shout_id, audio_data: str, lang: str = "en", gender: str = "female" + ) -> bool: + """ + Saves TTS Response under corresponding shout id + + :param shout_id: message id to consider + :param audio_data: base64 encoded audio data received + :param lang: language of speech (defaults to English) + :param gender: language gender (defaults to female) + + :return bool if saving was successful + """ + + audio_file_name = f"{shout_id}_{lang}_{gender}.wav" + try: + self.sftp_connector.put_file_object( + file_object=audio_data, save_to=f"audio/{audio_file_name}" + ) + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter("_id", shout_id), + data={f"audio.{lang}.{gender}": audio_file_name}, + ) + operation_success = True + except Exception as ex: + LOG.error(f"Failed to save TTS response to db - {ex}") + operation_success = False + return operation_success + + def save_stt_response(self, shout_id, message_text: str, lang: str = "en"): + """ + Saves STT Response under corresponding shout id + + :param shout_id: message id to consider + :param message_text: STT result transcript + :param lang: language of speech (defaults to English) + """ + try: + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter("_id", shout_id), + data={f"transcripts.{lang}": message_text}, + ) + except Exception as ex: + LOG.error(f"Failed to save STT response to db - {ex}") diff --git a/utils/database_utils/mongo_utils/queries/dao/users.py b/utils/database_utils/mongo_utils/queries/dao/users.py new file mode 100644 index 00000000..813cece5 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/users.py @@ -0,0 +1,217 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import copy +from time import time +from typing import Union + +from utils.common import generate_uuid, get_hash +from utils.logging_utils import LOG +from utils.database_utils.mongo_utils import ( + MongoCommands, + MongoDocuments, + MongoQuery, + MongoFilter, +) +from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO +from utils.database_utils.mongo_utils.queries.constants import UserPatterns + + +class UsersDAO(MongoDocumentDAO): + + _default_user_preferences = {"tts": {}, "chat_language_mapping": {}} + + @property + def document(self): + return MongoDocuments.USERS + + def get_user(self, user_id=None, nickname=None) -> Union[dict, None]: + """ + Gets user data based on provided params + :param user_id: target user id + :param nickname: target user nickname + """ + if not (user_id or nickname): + LOG.warning("Neither user_id nor nickname was provided") + return + filter_data = {} + if user_id: + filter_data["_id"] = user_id + if nickname: + filter_data["nickname"] = nickname + user = self.get_item(filters=filter_data) + if user and not user.get("preferences"): + user["preferences"] = self._default_user_preferences + self.set_preferences( + user_id=user_id, preferences_mapping=user["preferences"] + ) + return user + + def fetch_users_from_prompt(self, prompt: dict): + """Fetches user ids detected in provided prompt""" + prompt_data = prompt["data"] + user_ids = prompt_data.get("participating_subminds", []) + return self.list_contains( + source_set=user_ids, + projection_attributes=[ + "first_name", + "last_name", + "nickname", + "is_bot", + "avatar", + ], + ) + + @staticmethod + def create_from_pattern( + source: UserPatterns, override_defaults: dict = None + ) -> dict: + """ + Creates user record based on provided pattern from UserPatterns + + :param source: source pattern from UserPatterns + :param override_defaults: to override default values (optional) + :returns user data populated with default values where necessary + """ + if not override_defaults: + override_defaults = {} + + matching_data = {**copy.deepcopy(source.value), **override_defaults} + + matching_data.setdefault("_id", generate_uuid(length=20)) + matching_data.setdefault("password", get_hash(generate_uuid())) + matching_data.setdefault("date_created", int(time())) + matching_data.setdefault("is_tmp", True) + + return matching_data + + def get_neon_data(self, skill_name: str = "neon") -> dict: + """ + Gets a user profile for the user 'Neon' and adds it to the users db if not already present + + :param db_controller: db controller instance + :param skill_name: Neon Skill to consider (defaults to neon - Neon Assistant) + + :return Neon AI data + """ + neon_data = self.get_user(nickname=skill_name) + if not neon_data: + neon_data = self._register_neon_skill_user(skill_name=skill_name) + return neon_data + + def _register_neon_skill_user(self, skill_name: str): + last_name = "AI" if skill_name == "neon" else skill_name.capitalize() + nickname = skill_name + neon_data = self.create_from_pattern( + source=UserPatterns.NEON, + override_defaults={"last_name": last_name, "nickname": nickname}, + ) + self.add_item(data=neon_data) + return neon_data + + def get_bot_data(self, nickname: str, context: dict = None) -> dict: + """ + Gets a user profile for the requested bot instance and adds it to the users db if not already present + + :param nickname: nickname of the bot provided + :param context: context with additional bot information (optional) + + :return Matching bot data + """ + if not context: + context = {} + nickname = nickname.split("-")[0] + bot_data = self.get_user(nickname=nickname) + if not bot_data: + bot_data = self._create_bot(nickname=nickname, context=context) + self.add_item(data=bot_data) + elif not bot_data.get("is_bot") == "1": + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter("_id", bot_data["_id"]), + data={"is_bot": "1"}, + ) + return bot_data + + def _create_bot(self, nickname: str, context: dict) -> dict: + bot_data = dict( + _id=generate_uuid(length=20), + first_name="Bot", + last_name=context.get("last_name", nickname.capitalize()), + avatar=context.get("avatar", ""), + password=get_hash(generate_uuid()), + nickname=nickname, + is_bot="1", + full_nickname=nickname, # we treat each bot instance with equal nickname as same instance + date_created=int(time()), + is_tmp=False, + ) + self.add_item(data=bot_data) + return bot_data + + def set_preferences(self, user_id, preferences_mapping: dict): + """Sets user preferences for specified user according to preferences mapping""" + if user_id: + try: + update_mapping = { + f"preferences.{key}": val + for key, val in preferences_mapping.items() + } + self._execute_query( + command=MongoCommands.UPDATE_MANY, + filters=MongoFilter("_id", user_id), + data=update_mapping, + ) + except Exception as ex: + LOG.error(f"Failed to update preferences for user_id={user_id} - {ex}") + return preferences_mapping + + def create_guest(self, nano_token: str = None) -> dict: + """ + Creates unauthorized user and sets its credentials to cookies + + :param nano_token: nano token to append to user on creation + + :returns: generated UserData + """ + + guest_nickname = f"guest_{generate_uuid(length=8)}" + + if nano_token: + new_user = self.create_from_pattern( + source=UserPatterns.GUEST_NANO, + override_defaults=dict(nickname=guest_nickname, tokens=[nano_token]), + ) + else: + new_user = self.create_from_pattern( + source=UserPatterns.GUEST, + override_defaults=dict(nickname=guest_nickname), + ) + # TODO: consider adding partial TTL index for guest users + # https://www.mongodb.com/docs/manual/core/index-ttl/ + self.add_item(data=new_user) + return new_user diff --git a/utils/database_utils/mongo_utils/queries/mongo_queries.py b/utils/database_utils/mongo_utils/queries/mongo_queries.py new file mode 100644 index 00000000..17504928 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/mongo_queries.py @@ -0,0 +1,219 @@ +from typing import List, Tuple + +from utils.common import buffer_to_base64 +from .constants import UserPatterns, ConversationSkins +from .wrapper import MongoDocumentsAPI +from utils.logging_utils import LOG + + +def get_translations(translation_mapping: dict) -> Tuple[dict, dict]: + """ + Gets translation from db based on provided mapping + + :param translation_mapping: mapping of cid to desired translation language + + :return translations fetched from db + """ + populated_translations = {} + missing_translations = {} + for cid, cid_data in translation_mapping.items(): + lang = cid_data.get("lang", "en") + shout_ids = cid_data.get("shouts", []) + conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=cid + ) + if not conversation_data: + LOG.error(f"Failed to fetch conversation data - {cid}") + continue + shout_data = fetch_shout_data( + conversation_data=conversation_data, + shout_ids=shout_ids, + fetch_senders=False, + ) + shout_lang = "en" + if len(shout_data) == 1: + shout_lang = shout_data[0].get("message_lang", "en") + for shout in shout_data: + message_text = shout.get("message_text") + if shout_lang != "en" and lang == "en": + shout_text = message_text + else: + shout_text = shout.get("translations", {}).get(lang) + if shout_text and lang != "en": + populated_translations.setdefault(cid, {}).setdefault("shouts", {})[ + shout["_id"] + ] = shout_text + elif message_text: + missing_translations.setdefault(cid, {}).setdefault("shouts", {})[ + shout["_id"] + ] = message_text + if missing_translations.get(cid): + missing_translations[cid]["lang"] = lang + missing_translations[cid]["source_lang"] = shout_lang + return populated_translations, missing_translations + + +def fetch_message_data( + skin: ConversationSkins, + conversation_data: dict, + start_idx: int = 0, + limit: int = 100, + fetch_senders: bool = True, + start_message_id: str = None, +): + """Fetches message data based on provided conversation skin""" + message_data = fetch_shout_data( + conversation_data=conversation_data, + fetch_senders=fetch_senders, + start_idx=start_idx, + id_from=start_message_id, + limit=limit, + ) + for message in message_data: + message["message_type"] = "plain" + if skin == ConversationSkins.PROMPTS: + detected_prompts = list( + set(item.get("prompt_id") for item in message_data if item.get("prompt_id")) + ) + prompt_data = fetch_prompt_data( + cid=conversation_data["_id"], prompt_ids=detected_prompts + ) + if prompt_data: + detected_prompt_ids = [] + for prompt in prompt_data: + prompt["message_type"] = "prompt" + detected_prompt_ids.append(prompt["_id"]) + message_data = [ + message + for message in message_data + if message.get("prompt_id") not in detected_prompt_ids + ] + message_data.extend(prompt_data) + return sorted(message_data, key=lambda shout: int(shout["created_on"])) + + +def fetch_shout_data( + conversation_data: dict, + start_idx: int = 0, + limit: int = 100, + fetch_senders: bool = True, + id_from: str = None, + shout_ids: List[str] = None, +) -> List[dict]: + """ + Fetches shout data out of conversation data + + :param conversation_data: input conversation data + :param start_idx: message index to start from (sorted by recency) + :param limit: number of shouts to fetch + :param fetch_senders: to fetch shout senders data + :param id_from: message id to start from + :param shout_ids: list of shout ids to fetch + """ + if not shout_ids and conversation_data.get("chat_flow", None): + if id_from: + try: + start_idx = len(conversation_data["chat_flow"]) - conversation_data[ + "chat_flow" + ].index(id_from) + except ValueError: + LOG.warning("Matching start message id not found") + return [] + if start_idx == 0: + conversation_data["chat_flow"] = conversation_data["chat_flow"][ + start_idx - limit : + ] + else: + conversation_data["chat_flow"] = conversation_data["chat_flow"][ + -start_idx - limit : -start_idx + ] + shout_ids = [str(msg_id) for msg_id in conversation_data["chat_flow"]] + shouts = MongoDocumentsAPI.SHOUTS.fetch_shouts(shout_ids=shout_ids) + result = list() + if shouts and fetch_senders: + users_from_shouts = MongoDocumentsAPI.USERS.list_contains( + source_set=[shout["user_id"] for shout in shouts] + ) + for shout in shouts: + matching_user = users_from_shouts.get(shout["user_id"], {}) + if not matching_user: + matching_user = MongoDocumentsAPI.USERS.create_from_pattern( + UserPatterns.UNRECOGNIZED_USER + ) + else: + matching_user = matching_user[0] + matching_user.pop("password", None) + matching_user.pop("is_tmp", None) + shout["message_id"] = shout["_id"] + shout_data = {**shout, **matching_user} + result.append(shout_data) + shouts = result + return sorted(shouts, key=lambda user_shout: int(user_shout["created_on"])) + + +def fetch_prompt_data( + cid: str, + limit: int = 100, + id_from: str = None, + prompt_ids: List[str] = None, + fetch_user_data: bool = False, + created_from: int = None, +) -> List[dict]: + """ + Fetches prompt data out of conversation data + + :param cid: target conversation id + :param limit: number of prompts to fetch + :param id_from: prompt id to start from + :param prompt_ids: prompt ids to fetch + :param fetch_user_data: to fetch user data in the + :param created_from: timestamp to filter messages from + + :returns list of matching prompt data along with matching messages and users + """ + matching_prompts = MongoDocumentsAPI.PROMPTS.get_prompts( + cid=cid, + limit=limit, + id_from=id_from, + prompt_ids=prompt_ids, + created_from=created_from, + ) + for prompt in matching_prompts: + prompt["user_mapping"] = MongoDocumentsAPI.USERS.fetch_users_from_prompt(prompt) + prompt["message_mapping"] = MongoDocumentsAPI.SHOUTS.fetch_messages_from_prompt( + prompt + ) + if fetch_user_data: + for user in prompt.get("data", {}).get("participating_subminds", []): + try: + nick = prompt["user_mapping"][user][0]["nickname"] + except KeyError: + LOG.warning( + f'user_id - "{user}" was not detected setting it as nick' + ) + nick = user + for k in ( + "proposed_responses", + "submind_opinions", + "votes", + ): + msg_id = prompt["data"][k].pop(user, "") + if msg_id: + prompt["data"][k][nick] = ( + prompt["message_mapping"] + .get(msg_id, [{}])[0] + .get("message_text") + or msg_id + ) + prompt["data"]["participating_subminds"] = [ + prompt["user_mapping"][x][0]["nickname"] + for x in prompt["data"]["participating_subminds"] + ] + return sorted(matching_prompts, key=lambda _prompt: int(_prompt["created_on"])) + + +def add_shout(data: dict): + MongoDocumentsAPI.SHOUTS.add_item(data=data) + if cid := data.get("cid"): + shout_id = data["_id"] + MongoDocumentsAPI.CHATS.add_shout(cid=cid, shout_id=shout_id) diff --git a/utils/database_utils/mongo_utils/queries/wrapper.py b/utils/database_utils/mongo_utils/queries/wrapper.py new file mode 100644 index 00000000..2e77f349 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/wrapper.py @@ -0,0 +1,68 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# DAO Imports +from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO +from utils.database_utils.mongo_utils.queries.dao.users import UsersDAO +from utils.database_utils.mongo_utils.queries.dao.chats import ChatsDAO +from utils.database_utils.mongo_utils.queries.dao.shouts import ShoutsDAO +from utils.database_utils.mongo_utils.queries.dao.prompts import PromptsDAO + + +class MongoDAOGateway(type): + def __getattribute__(self, name): + item = super().__getattribute__(name) + try: + if issubclass(item, MongoDocumentDAO): + item = item( + db_controller=self.db_controller, sftp_connector=self.sftp_connector + ) + except: + pass + return item + + +class MongoDocumentsAPI(metaclass=MongoDAOGateway): + """ + Wrapper for DB commands execution + If getting attribute is triggered, initialises relevant instance of DAO handler and returns it + """ + + db_controller = None + sftp_connector = None + + USERS = UsersDAO + CHATS = ChatsDAO + SHOUTS = ShoutsDAO + PROMPTS = PromptsDAO + + @classmethod + def init(cls, db_controller, sftp_connector=None): + """Inits Singleton with specified database controller""" + cls.db_controller = db_controller + cls.sftp_connector = sftp_connector From 7b0832454f5012340986b0458e6624a930e854cd Mon Sep 17 00:00:00 2001 From: NeonKirill Date: Mon, 29 Jan 2024 22:37:28 +0100 Subject: [PATCH 2/5] Fixed unittests --- chat_server/sio.py | 4 ++-- chat_server/tests/test_sio.py | 20 +++++++++---------- .../mongo_utils/queries/dao/users.py | 7 +++---- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/chat_server/sio.py b/chat_server/sio.py index efc7f6c1..2c0d6e49 100644 --- a/chat_server/sio.py +++ b/chat_server/sio.py @@ -181,7 +181,7 @@ async def user_message(sid, data): data["userID"] = neon_data["_id"] elif data["is_bot"] == "1": bot_data = MongoDocumentsAPI.USERS.get_bot_data( - nickname=data["userID"], context=data.get("context") + user_id=data["userID"], context=data.get("context") ) data["userID"] = bot_data["_id"] @@ -233,7 +233,7 @@ async def user_message(sid, data): new_shout_data["translations"][lang] = data["messageText"] mongo_queries.add_shout(data=new_shout_data) - if is_announcement == "0" and data["prompt_id"]: + if is_announcement == "0" and data.get("prompt_id"): is_ok = MongoDocumentsAPI.PROMPTS.add_shout_to_prompt(data) if is_ok: await sio.emit( diff --git a/chat_server/tests/test_sio.py b/chat_server/tests/test_sio.py index d0a0ca25..8ff5d39c 100644 --- a/chat_server/tests/test_sio.py +++ b/chat_server/tests/test_sio.py @@ -131,7 +131,6 @@ def test_neon_message(self): user_id = "neon" message_data = { "userID": "neon", - "messageID": message_id, "messageText": "Neon Test 123", "bot": "0", "cid": "-1", @@ -153,27 +152,26 @@ def test_neon_message(self): query={ "command": "find_one", "document": "shouts", - "data": {"_id": message_id}, + "data": {"user_id": neon["_id"]}, } ) self.assertIsNotNone(shout) self.assertIsInstance(shout, dict) db_controller.exec_query( query={ - "command": "delete_one", + "command": "delete_many", "document": "shouts", - "data": {"_id": message_id}, + "data": {"_id": neon["_id"]}, } ) @pytest.mark.usefixtures("create_server") def test_bot_message(self): - message_id = f"test_bot_message_{generate_uuid()}" user_id = f"test_bot_{generate_uuid()}" + message_text = f"Bot Test {generate_uuid()}" message_data = { "userID": user_id, - "messageID": message_id, - "messageText": "Bot Test 123", + "messageText": message_text, "bot": "1", "cid": "-1", "context": dict(first_name="The", last_name="Bot"), @@ -191,14 +189,14 @@ def test_bot_message(self): ) self.assertIsNotNone(bot) self.assertIsInstance(bot, dict) - self.assertTrue(bot["first_name"] == "The") + self.assertTrue(bot["first_name"] == "Bot") self.assertTrue(bot["last_name"] == "Bot") shout = db_controller.exec_query( query={ "command": "find_one", "document": "shouts", - "data": {"_id": message_id}, + "data": {"user_id": bot["_id"]}, } ) self.assertIsNotNone(shout) @@ -206,9 +204,9 @@ def test_bot_message(self): db_controller.exec_query( query={ - "command": "delete_one", + "command": "delete_many", "document": "shouts", - "data": {"_id": message_id}, + "data": {"user_id": bot["_id"]}, } ) db_controller.exec_query( diff --git a/utils/database_utils/mongo_utils/queries/dao/users.py b/utils/database_utils/mongo_utils/queries/dao/users.py index 813cece5..caf08790 100644 --- a/utils/database_utils/mongo_utils/queries/dao/users.py +++ b/utils/database_utils/mongo_utils/queries/dao/users.py @@ -133,22 +133,21 @@ def _register_neon_skill_user(self, skill_name: str): self.add_item(data=neon_data) return neon_data - def get_bot_data(self, nickname: str, context: dict = None) -> dict: + def get_bot_data(self, user_id: str, context: dict = None) -> dict: """ Gets a user profile for the requested bot instance and adds it to the users db if not already present - :param nickname: nickname of the bot provided + :param user_id: user id of the bot provided :param context: context with additional bot information (optional) :return Matching bot data """ if not context: context = {} - nickname = nickname.split("-")[0] + nickname = user_id.split("-")[0] bot_data = self.get_user(nickname=nickname) if not bot_data: bot_data = self._create_bot(nickname=nickname, context=context) - self.add_item(data=bot_data) elif not bot_data.get("is_bot") == "1": self._execute_query( command=MongoCommands.UPDATE_MANY, From a94d596b9c878d8eef5c82d24d8546fcb4223211 Mon Sep 17 00:00:00 2001 From: NeonKirill Date: Wed, 31 Jan 2024 00:49:33 +0100 Subject: [PATCH 3/5] Addressed comments --- chat_server/blueprints/chat.py | 8 +++-- .../mongo_utils/queries/constants.py | 28 ++++++++++++++++ .../mongo_utils/queries/dao/abc.py | 18 ++++++++-- .../mongo_utils/queries/dao/shouts.py | 3 +- .../mongo_utils/queries/dao/users.py | 5 ++- .../mongo_utils/queries/mongo_queries.py | 33 +++++++++++++++++-- 6 files changed, 84 insertions(+), 11 deletions(-) diff --git a/chat_server/blueprints/chat.py b/chat_server/blueprints/chat.py index 2dffc683..ba854ee5 100644 --- a/chat_server/blueprints/chat.py +++ b/chat_server/blueprints/chat.py @@ -25,6 +25,7 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import warnings from typing import Optional from time import time @@ -67,9 +68,12 @@ async def new_conversation( :returns JSON response with new conversation data if added, 401 error message otherwise """ - + if conversation_id: + warnings.warn( + "Param conversation id is no longer considered", DeprecationWarning + ) conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( - search_str=[conversation_id, conversation_name], + search_str=[conversation_name], ) if conversation_data: return respond(f'Conversation "{conversation_name}" already exists', 400) diff --git a/utils/database_utils/mongo_utils/queries/constants.py b/utils/database_utils/mongo_utils/queries/constants.py index 0a41108c..95ae07ba 100644 --- a/utils/database_utils/mongo_utils/queries/constants.py +++ b/utils/database_utils/mongo_utils/queries/constants.py @@ -1,3 +1,31 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from enum import Enum diff --git a/utils/database_utils/mongo_utils/queries/dao/abc.py b/utils/database_utils/mongo_utils/queries/dao/abc.py index c432ec6e..c5abba3c 100644 --- a/utils/database_utils/mongo_utils/queries/dao/abc.py +++ b/utils/database_utils/mongo_utils/queries/dao/abc.py @@ -60,7 +60,14 @@ def list_contains( aggregate_result: bool = True, *args, **kwargs - ) -> dict: + ) -> dict[str, list] | list[str]: + """ + Lists items that are members of :param source_set under the :param key + :param key: attribute to query + :param source_set: collection of values to lookup + :param aggregate_result: to apply aggregation by key on result (defaults to True) + :return matching items + """ items = {} contains_filter = self._build_contains_filter(key=key, lookup_set=source_set) if contains_filter: @@ -105,7 +112,11 @@ def list_items( ] return items - def aggregate_items_by_key(self, key, items: list) -> dict: + def aggregate_items_by_key(self, key: str, items: list[dict]) -> dict: + """ + Aggregates list of dictionaries according to the provided key + :return dictionary mapping id -> list of matching items + """ aggregated_data = {} # TODO: consider Mongo DB aggregation API for item in items: @@ -134,7 +145,8 @@ def _build_contains_filter(self, key, lookup_set) -> MongoFilter | None: ) return mongo_filter - def add_item(self, data: dict): + def add_item(self, data: dict) -> bool: + """Inserts provided data into the object's document""" return self._execute_query(command=MongoCommands.INSERT_ONE, data=data) def get_item( diff --git a/utils/database_utils/mongo_utils/queries/dao/shouts.py b/utils/database_utils/mongo_utils/queries/dao/shouts.py index 490c674a..66ac246c 100644 --- a/utils/database_utils/mongo_utils/queries/dao/shouts.py +++ b/utils/database_utils/mongo_utils/queries/dao/shouts.py @@ -69,10 +69,11 @@ def fetch_messages_from_prompt(self, prompt: dict): message_ids.extend(list(prompt_data.get(column, {}).values())) return self.list_contains(source_set=message_ids) - def fetch_audio_data(self, message_id: str): + def fetch_audio_data(self, message_id: str) -> str | None: """ Fetches audio data from message :param message_id: message id to fetch + :returns base64 encoded audio data if any """ shout_data = self.get_item(item_id=message_id) if not shout_data: diff --git a/utils/database_utils/mongo_utils/queries/dao/users.py b/utils/database_utils/mongo_utils/queries/dao/users.py index caf08790..ef4188ad 100644 --- a/utils/database_utils/mongo_utils/queries/dao/users.py +++ b/utils/database_utils/mongo_utils/queries/dao/users.py @@ -71,7 +71,7 @@ def get_user(self, user_id=None, nickname=None) -> Union[dict, None]: ) return user - def fetch_users_from_prompt(self, prompt: dict): + def fetch_users_from_prompt(self, prompt: dict) -> dict[str, list]: """Fetches user ids detected in provided prompt""" prompt_data = prompt["data"] user_ids = prompt_data.get("participating_subminds", []) @@ -174,7 +174,7 @@ def _create_bot(self, nickname: str, context: dict) -> dict: def set_preferences(self, user_id, preferences_mapping: dict): """Sets user preferences for specified user according to preferences mapping""" - if user_id: + if user_id and preferences_mapping: try: update_mapping = { f"preferences.{key}": val @@ -187,7 +187,6 @@ def set_preferences(self, user_id, preferences_mapping: dict): ) except Exception as ex: LOG.error(f"Failed to update preferences for user_id={user_id} - {ex}") - return preferences_mapping def create_guest(self, nano_token: str = None) -> dict: """ diff --git a/utils/database_utils/mongo_utils/queries/mongo_queries.py b/utils/database_utils/mongo_utils/queries/mongo_queries.py index 17504928..6bb4bc29 100644 --- a/utils/database_utils/mongo_utils/queries/mongo_queries.py +++ b/utils/database_utils/mongo_utils/queries/mongo_queries.py @@ -1,6 +1,34 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + from typing import List, Tuple -from utils.common import buffer_to_base64 from .constants import UserPatterns, ConversationSkins from .wrapper import MongoDocumentsAPI from utils.logging_utils import LOG @@ -60,7 +88,7 @@ def fetch_message_data( limit: int = 100, fetch_senders: bool = True, start_message_id: str = None, -): +) -> list[dict]: """Fetches message data based on provided conversation skin""" message_data = fetch_shout_data( conversation_data=conversation_data, @@ -213,6 +241,7 @@ def fetch_prompt_data( def add_shout(data: dict): + """Records shout data and pushes its id to the relevant conversation flow""" MongoDocumentsAPI.SHOUTS.add_item(data=data) if cid := data.get("cid"): shout_id = data["_id"] From 6015ce181a8a89af54237499d1e9dceba6d8d0f3 Mon Sep 17 00:00:00 2001 From: NeonKirill Date: Sat, 3 Feb 2024 23:22:41 +0100 Subject: [PATCH 4/5] Fixed issue with prompts --- utils/database_utils/mongo_utils/queries/dao/abc.py | 10 ---------- utils/database_utils/mongo_utils/queries/dao/chats.py | 5 ----- utils/database_utils/mongo_utils/queries/dao/users.py | 11 +---------- 3 files changed, 1 insertion(+), 25 deletions(-) diff --git a/utils/database_utils/mongo_utils/queries/dao/abc.py b/utils/database_utils/mongo_utils/queries/dao/abc.py index c5abba3c..d8b8c078 100644 --- a/utils/database_utils/mongo_utils/queries/dao/abc.py +++ b/utils/database_utils/mongo_utils/queries/dao/abc.py @@ -80,7 +80,6 @@ def list_contains( def list_items( self, filters: list[MongoFilter] = None, - projection_attributes: list = None, limit: int = None, result_as_cursor: bool = True, ) -> dict: @@ -88,7 +87,6 @@ def list_items( Lists items under provided document belonging to source set of provided column values :param filters: filters to consider (optional) - :param projection_attributes: list of value keys to return (optional) :param limit: limit number of returned attributes (optional) :param result_as_cursor: to return result as cursor (defaults to True) :returns results of FIND operation over the desired document according to applied filters @@ -102,14 +100,6 @@ def list_items( result_filters=result_filters, result_as_cursor=result_as_cursor, ) - # TODO: pymongo support projection only as aggregation API which is not yet implemented in project - if projection_attributes: - items = [ - {k: v} - for item in items - for k, v in item.items() - if k in projection_attributes - ] return items def aggregate_items_by_key(self, key: str, items: list[dict]) -> dict: diff --git a/utils/database_utils/mongo_utils/queries/dao/chats.py b/utils/database_utils/mongo_utils/queries/dao/chats.py index 53f02295..9b71eb3d 100644 --- a/utils/database_utils/mongo_utils/queries/dao/chats.py +++ b/utils/database_utils/mongo_utils/queries/dao/chats.py @@ -52,7 +52,6 @@ def get_conversation_data( column_identifiers: List[str] = None, limit: int = 1, allow_regex_search: bool = False, - projection_attributes: dict = None, include_private: bool = False, ) -> Union[None, dict]: """ @@ -61,7 +60,6 @@ def get_conversation_data( :param column_identifiers: desired column identifiers to look up :param limit: limit found conversations :param allow_regex_search: to allow search for matching entries that CONTAIN :param search_str - :param projection_attributes: mapping of attributes to project (optional) :param include_private: to include private conversations (defaults to False) """ if isinstance(search_str, str): @@ -90,7 +88,6 @@ def get_conversation_data( value=or_expression, logical_operator=MongoLogicalOperators.OR ) ], - projection_attributes=projection_attributes, limit=limit, result_as_cursor=False, include_private=include_private, @@ -112,7 +109,6 @@ def add_shout(self, cid: str, shout_id: str): def list_items( self, filters: list[MongoFilter] = None, - projection_attributes: list = None, limit: int = None, result_as_cursor: bool = True, include_private: bool = False, @@ -122,7 +118,6 @@ def list_items( filters.append(MongoFilter(key="is_private", value=False)) return super().list_items( filters=filters, - projection_attributes=projection_attributes, limit=limit, result_as_cursor=result_as_cursor, ) diff --git a/utils/database_utils/mongo_utils/queries/dao/users.py b/utils/database_utils/mongo_utils/queries/dao/users.py index ef4188ad..d96e07c1 100644 --- a/utils/database_utils/mongo_utils/queries/dao/users.py +++ b/utils/database_utils/mongo_utils/queries/dao/users.py @@ -75,16 +75,7 @@ def fetch_users_from_prompt(self, prompt: dict) -> dict[str, list]: """Fetches user ids detected in provided prompt""" prompt_data = prompt["data"] user_ids = prompt_data.get("participating_subminds", []) - return self.list_contains( - source_set=user_ids, - projection_attributes=[ - "first_name", - "last_name", - "nickname", - "is_bot", - "avatar", - ], - ) + return self.list_contains(source_set=user_ids) @staticmethod def create_from_pattern( From bc32d749621bd4a653d6d16d44369cead926cd4a Mon Sep 17 00:00:00 2001 From: NeonKirill Date: Sun, 18 Feb 2024 18:13:52 +0100 Subject: [PATCH 5/5] removed redundant property from request_tts --- chat_server/sio.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/chat_server/sio.py b/chat_server/sio.py index 2c0d6e49..4216b2ab 100644 --- a/chat_server/sio.py +++ b/chat_server/sio.py @@ -478,14 +478,12 @@ async def request_tts(sid, data): required_keys = ( "cid", "message_id", - "user_id", ) if not all(key in list(data) for key in required_keys): LOG.error(f"Missing one of the required keys - {required_keys}") else: lang = data.get("lang", "en") message_id = data["message_id"] - user_id = data["user_id"] cid = data["cid"] matching_message = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id) if not matching_message: