Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for admin api #72

Merged
merged 15 commits into from
Dec 4, 2023
Merged
46 changes: 46 additions & 0 deletions chat_server/blueprints/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from fastapi import APIRouter
NeonKirill marked this conversation as resolved.
Show resolved Hide resolved
from starlette.requests import Request

from utils.logging_utils import LOG
from utils.http_utils import respond

from chat_server.server_config import mq_api, mq_management_config, k8s_config
from chat_server.sio import login_required
from chat_server.server_utils.k8s_utils import restart_deployment
from chat_server.server_utils.admin_utils import run_mq_validation

router = APIRouter(
prefix="/admin",
responses={"404": {"description": "Unknown authorization endpoint"}},
)


@router.post("/refresh/{service_name}")
@login_required(tmp_allowed=False, required_roles=["admin"])
async def refresh_state(
request: Request, service_name: str, target_items: str | None = ""
):
"""
Refreshes state of the target

:param request: Starlette Request Object
:param service_name: name of service to refresh
:param target_items: comma-separated list of items to refresh

:returns JSON-formatted response from server
"""
target_items = [x for x in target_items.split(",") if x]
if service_name == "k8s":
if not k8s_config:
return respond("K8S Service Unavailable", 503)
deployments = target_items
if deployments == "*":
deployments = k8s_config.get("MANAGED_DEPLOYMENTS", [])
LOG.info(f"Restarting {deployments=!r}")
for deployment in deployments:
restart_deployment(deployment_name=deployment)
elif service_name == "mq":
run_mq_validation()
else:
return respond(f"Unknown refresh type: {service_name!r}", 404)
return respond("OK")
15 changes: 14 additions & 1 deletion chat_server/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@

import os

from utils.logging_utils import LOG

from config import Configuration
from chat_server.server_utils.sftp_utils import init_sftp_connector
from utils.logging_utils import LOG
from chat_server.server_utils.rmq_utils import RabbitMQAPI

server_config_path = os.environ.get(
"CHATSERVER_CONFIG", "~/.local/share/neon/credentials.json"
Expand All @@ -50,3 +52,14 @@
db_controller = config.get_db_controller(name="pyklatchat_3333")

sftp_connector = init_sftp_connector(config=app_config.get("SFTP", {}))

mq_api = None
mq_management_config = config.get("MQ_MANAGEMENT", {})
if mq_management_url := mq_management_config.get("MQ_MANAGEMENT_URL"):
mq_api = RabbitMQAPI(url=mq_management_url)
mq_api.login(
username=mq_management_config["MQ_MANAGEMENT_LOGIN"],
password=mq_management_config["MQ_MANAGEMENT_PASSWORD"],
)

k8s_config = config.get("K8S_CONFIG", {})
25 changes: 25 additions & 0 deletions chat_server/server_utils/admin_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from chat_server.server_config import mq_api, mq_management_config, LOG
NeonKirill marked this conversation as resolved.
Show resolved Hide resolved


def run_mq_validation():
if mq_api:
for vhost in mq_management_config.get("VHOSTS", []):
status = mq_api.add_vhost(vhost=vhost["name"])
if not status.ok:
raise ConnectionError(f'Failed to add {vhost["name"]}, {status=}')
for user_creds in mq_management_config.get("USERS", []):
mq_api.add_user(
user=user_creds["name"],
password=user_creds["password"],
tags=user_creds.get("tags", ""),
)
for user_vhost_permissions in mq_management_config.get(
"USER_VHOST_PERMISSIONS", []
):
mq_api.configure_vhost_user_permissions(**user_vhost_permissions)
else:
LOG.error("MQ API is unavailable")


if __name__ == "__main__":
run_mq_validation()
23 changes: 18 additions & 5 deletions chat_server/server_utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,10 @@ def refresh_session(payload: dict):


def validate_session(
request: Union[str, Request], check_tmp: bool = False, sio_request: bool = False
request: Union[str, Request],
check_tmp: bool = False,
required_roles: list = None,
sio_request: bool = False,
) -> Tuple[str, int]:
"""
Check if session token contained in request is valid
Expand All @@ -268,12 +271,20 @@ def validate_session(
payload = jwt.decode(
jwt=session, key=secret_key, algorithms=jwt_encryption_algo
)
if check_tmp:
should_check_user_data = check_tmp or required_roles
is_authorized = True
if should_check_user_data:
from chat_server.server_utils.db_utils import DbUtils

user = DbUtils.get_user(user_id=payload["sub"])
if user.get("is_tmp"):
return "Permission denied", 403
if check_tmp and user.get("is_tmp"):
is_authorized = False
elif required_roles and not any(
user_role in required_roles for user_role in user.get("roles", [])
):
is_authorized = False
if not is_authorized:
return "Permission denied", 403
if (int(time()) - int(payload.get("creation_time", 0))) <= session_lifetime:
return "OK", 200
return "Session Expired", 401
Expand All @@ -297,7 +308,9 @@ def outer(func):
@wraps(func)
async def wrapper(request: Request, *args, **kwargs):
session_validation_output = validate_session(
request, check_tmp=not outer_kwargs.get("tmp_allowed")
request,
check_tmp=not outer_kwargs.get("tmp_allowed"),
required_roles=outer_kwargs.get("required_roles"),
)
LOG.debug(
f"(url={request.url}) Received session validation output: {session_validation_output}"
Expand Down
10 changes: 5 additions & 5 deletions chat_server/server_utils/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def save_translations(cls, translation_mapping: dict) -> Dict[str, List[str]]:
filter_expression = {"_id": shout_id}
cls.db_controller.exec_query(
query=MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.SHOUTS,
filters=filter_expression,
data={"translations": {}},
Expand Down Expand Up @@ -546,7 +546,7 @@ def get_user_preferences(cls, user_id):
if user and not user.get("preferences"):
cls.db_controller.exec_query(
MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.USERS,
filters=MongoFilter(key="_id", value=user_id),
data={"preferences": prefs},
Expand All @@ -570,7 +570,7 @@ def set_user_preferences(cls, user_id, preferences_mapping: dict):
}
cls.db_controller.exec_query(
MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.USERS,
filters=MongoFilter("_id", user_id),
data=update_mapping,
Expand Down Expand Up @@ -603,7 +603,7 @@ def save_tts_response(
)
cls.db_controller.exec_query(
query=MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.SHOUTS,
filters=MongoFilter("_id", shout_id),
data={f"audio.{lang}.{gender}": audio_file_name},
Expand All @@ -628,7 +628,7 @@ def save_stt_response(cls, shout_id, message_text: str, lang: str = "en"):
try:
cls.db_controller.exec_query(
query=MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.SHOUTS,
filters=MongoFilter("_id", shout_id),
data={f"transcripts.{lang}": message_text},
Expand Down
53 changes: 53 additions & 0 deletions chat_server/server_utils/k8s_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import datetime
NeonKirill marked this conversation as resolved.
Show resolved Hide resolved
import os

from kubernetes import client, config
from kubernetes.client.rest import ApiException

from chat_server.server_config import k8s_config
from utils.logging_utils import LOG

k8s_app_api = None
_k8s_default_namespace = "default"

if k8s_config:
_k8s_config_path = k8s_config.get("K8S_CONFIG_PATH")
_k8s_default_namespace = (
k8s_config.get("K8S_DEFAULT_NAMESPACE") or _k8s_default_namespace
)
config.load_kube_config(k8s_config.get("K8S_CONFIG_PATH"))

k8s_app_api = client.AppsV1Api()
else:
LOG.warning("K8S config is unset!")


def restart_deployment(deployment_name: str, namespace: str = _k8s_default_namespace):
"""
Restarts K8S deployment
:param deployment_name: name of the deployment to restart
:param namespace: name of the namespace
"""
if not k8s_app_api:
LOG.error(
f"Failed to restart {deployment_name=!r} ({namespace=!r}) - missing K8S configs"
)
return -1
now = datetime.datetime.utcnow()
now = str(now.isoformat() + "Z")
body = {
"spec": {
"template": {
"metadata": {"annotations": {"kubectl.kubernetes.io/restartedAt": now}}
}
}
}
try:
k8s_app_api.patch_namespaced_deployment(
deployment_name, namespace, body, pretty="true"
)
except ApiException as e:
LOG.error(
"Exception when calling AppsV1Api->read_namespaced_deployment_status: %s\n"
% e
)
4 changes: 2 additions & 2 deletions chat_server/server_utils/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def handle_prompt_message(message: dict) -> bool:
}
db_controller.exec_query(
MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.PROMPTS,
filters=MongoFilter(key="_id", value=prompt_id),
**data_kwargs,
Expand Down Expand Up @@ -128,7 +128,7 @@ def handle_prompt_message(message: dict) -> bool:
}
db_controller.exec_query(
MongoQuery(
command=MongoCommands.UPDATE,
command=MongoCommands.UPDATE_MANY,
document=MongoDocuments.PROMPTS,
filters=MongoFilter(key="_id", value=prompt_id),
**data_kwargs,
Expand Down
Loading
Loading