Skip to content

Commit

Permalink
Implement websocket client API (#20)
Browse files Browse the repository at this point in the history
# Description
Implements a websocket API for a Node client
Adds `ClientPermissions` object to define per-client permissions

# Issues
Closes #6

# Other Notes
Example client implementation:
NeonGeckoCom/neon-nodes#14

---------

Co-authored-by: Daniel McKnight <[email protected]>
  • Loading branch information
NeonDaniel and NeonDaniel authored May 21, 2024
1 parent 1429843 commit 589da9b
Show file tree
Hide file tree
Showing 14 changed files with 636 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ COPY docker_overlay/ /

WORKDIR /app
COPY . /app
RUN pip install /app
RUN pip install /app[websocket]

CMD ["python3", "/app/neon_hana/app/__main__.py"]
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ hana:
stt_max_length_encoded: 500000 # Arbitrary limit that is larger than any expected voice command
tts_max_words: 128 # Arbitrary limit that is longer than any default LLM token limit
enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address

node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access
node_password: node_password # Password associated with node_username
```
It is recommended to generate unique values for configured tokens, these are 32
bytes in hexadecimal representation.
Expand Down
2 changes: 2 additions & 0 deletions neon_hana/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from neon_hana.app.routers.mq_backend import mq_route
from neon_hana.app.routers.auth import auth_route
from neon_hana.app.routers.util import util_route
from neon_hana.app.routers.node_server import node_route
from neon_hana.version import __version__


Expand All @@ -47,5 +48,6 @@ def create_app(config: dict):
app.include_router(mq_route)
app.include_router(llm_route)
app.include_router(util_route)
app.include_router(node_route)

return app
75 changes: 75 additions & 0 deletions neon_hana/app/routers/node_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2021 Neongecko.com Inc.
# BSD-3
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from asyncio import Event
from signal import signal, SIGINT
from typing import Optional, Union

from fastapi import APIRouter, WebSocket, HTTPException, Request
from starlette.websockets import WebSocketDisconnect

from neon_hana.app.dependencies import config, client_manager
from neon_hana.mq_websocket_api import MQWebsocketAPI

from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt,
NodeGetTts, NodeKlatResponse,
NodeAudioInputResponse,
NodeGetSttResponse,
NodeGetTtsResponse)
node_route = APIRouter(prefix="/node", tags=["node"])

socket_api = MQWebsocketAPI(config)
signal(SIGINT, socket_api.shutdown)


@node_route.websocket("/v1")
async def node_v1_endpoint(websocket: WebSocket, token: str):
client_id = client_manager.get_client_id(token)
if not client_manager.validate_auth(token, client_id):
raise HTTPException(status_code=403,
detail="Invalid or expired token.")
if not client_manager.get_permissions(client_id).node:
raise HTTPException(status_code=401,
detail=f"Client not authorized for node access "
f"({client_id})")
await websocket.accept()
disconnect_event = Event()

socket_api.new_connection(websocket, client_id)
while not disconnect_event.is_set():
try:
client_in: dict = await websocket.receive_json()
socket_api.handle_client_input(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()


@node_route.get("/v1/doc")
async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt,
NodeGetTts]]) -> \
Optional[Union[NodeKlatResponse, NodeAudioInputResponse,
NodeGetSttResponse, NodeGetTtsResponse]]:
pass
52 changes: 51 additions & 1 deletion neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jwt import DecodeError
from ovos_utils import LOG
from token_throttler import TokenThrottler, TokenBucket
from token_throttler.storage import RuntimeStorage

from neon_hana.auth.permissions import ClientPermissions


class ClientManager:
def __init__(self, config: dict):
Expand All @@ -48,9 +51,15 @@ def __init__(self, config: dict):
self._rpm = config.get("requests_per_minute", 60)
self._auth_rpm = config.get("auth_requests_per_minute", 6)
self._disable_auth = config.get("disable_auth")
self._node_username = config.get("node_username")
self._node_password = config.get("node_password")
self._jwt_algo = "HS256"

def _create_tokens(self, encode_data: dict) -> dict:
# Permissions were not included in old tokens, allow refreshing with
# default permissions
encode_data.setdefault("permissions", ClientPermissions().as_dict())

token_expiration = encode_data['expire']
token = jwt.encode(encode_data, self._access_secret, self._jwt_algo)
encode_data['expire'] = time() + self._refresh_token_lifetime
Expand All @@ -59,13 +68,38 @@ def _create_tokens(self, encode_data: dict) -> dict:
# TODO: Store refresh token on server to allow invalidating clients
return {"username": encode_data['username'],
"client_id": encode_data['client_id'],
"permissions": encode_data['permissions'],
"access_token": token,
"refresh_token": refresh,
"expiration": token_expiration}

def get_permissions(self, client_id: str) -> ClientPermissions:
"""
Get ClientPermissions model for the given client_id
@param client_id: Client ID to get permissions for
@return: ClientPermissions object for the specified client
"""
if self._disable_auth:
LOG.debug("Auth disabled, allow full client permissions")
return ClientPermissions(assist=True, backend=True, node=True)
if client_id not in self.authorized_clients:
LOG.warning(f"{client_id} not known to this server")
return ClientPermissions(assist=False, backend=False, node=False)
client = self.authorized_clients[client_id]
return ClientPermissions(**client.get('permissions', dict()))

def check_auth_request(self, client_id: str, username: str,
password: Optional[str] = None,
origin_ip: str = "127.0.0.1"):
origin_ip: str = "127.0.0.1") -> dict:
"""
Authenticate and Authorize a new client connection with the specified
username, password, and origin IP address.
@param client_id: Client ID of the connection to auth
@param username: Supplied username to authenticate
@param password: Supplied password to authenticate
@param origin_ip: Origin IP address of request
@return: response tokens, permissions, and other metadata
"""
if client_id in self.authorized_clients:
print(f"Using cached client: {self.authorized_clients[client_id]}")
return self.authorized_clients[client_id]
Expand All @@ -84,13 +118,19 @@ def check_auth_request(self, client_id: str, username: str,
detail=f"Too many auth requests from: "
f"{origin_ip}. Wait {wait_time}s.")

node_access = False
if username != "guest":
# TODO: Validate password here
pass
if all((self._node_username, username == self._node_username,
password == self._node_password)):
node_access = True
permissions = ClientPermissions(node=node_access)
expiration = time() + self._access_token_lifetime
encode_data = {"client_id": client_id,
"username": username,
"password": password,
"permissions": permissions.as_dict(),
"expire": expiration}
auth = self._create_tokens(encode_data)
self.authorized_clients[client_id] = auth
Expand Down Expand Up @@ -125,6 +165,15 @@ def check_refresh_request(self, access_token: str, refresh_token: str,
new_auth = self._create_tokens(encode_data)
return new_auth

def get_client_id(self, token: str) -> str:
"""
Extract the client_id from a JWT token
@param token: JWT token to parse
@return: client_id associated with token
"""
auth = jwt.decode(token, self._access_secret, self._jwt_algo)
return auth['client_id']

def validate_auth(self, token: str, origin_ip: str) -> bool:
if not self.rate_limiter.get_all_buckets(origin_ip):
self.rate_limiter.add_bucket(origin_ip,
Expand All @@ -142,6 +191,7 @@ def validate_auth(self, token: str, origin_ip: str) -> bool:
if auth['expire'] < time():
self.authorized_clients.pop(auth['client_id'], None)
return False
self.authorized_clients[auth['client_id']] = auth
return True
except DecodeError:
# Invalid token supplied
Expand Down
43 changes: 43 additions & 0 deletions neon_hana/auth/permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2021 Neongecko.com Inc.
# BSD-3
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from dataclasses import dataclass, asdict


@dataclass
class ClientPermissions:
"""
Data class representing permissions of a particular client connection.
"""
assist: bool = True
backend: bool = True
node: bool = False

def as_dict(self) -> dict:
"""
Get a dict representation of this instance.
"""
return asdict(self)
Loading

0 comments on commit 589da9b

Please sign in to comment.