Skip to content

Commit

Permalink
Update TokenConfig and add JWT models (#2)
Browse files Browse the repository at this point in the history
* Add BaseModel for JWT and `HanaToken` (to use for auth and refresh)
Update TokenConfig model to reflect expected database schema
Refactor AccessRoles values to allow for more in-between values
Add methods to read/dump PermissionsConfig to JWT-compatible role strings

* Update TokenConfig to better accept alternate named parameters with comment
Update unit test to account for token config change

* Remove aliased params in `HanaToken` class

* Fix syntax error in annotated fields

* Update token timestamp descriptions for accuracy

* Update TokenConfig to handle JWT standard field names
Add test coverage for TokenConfig

* Update comment in TokenConfig

* Mark `TokenConfig` as deprecated
Refactor `HanaToken` to include params previously used in `TokenConfig`

* Update tests to reflect changes to `HanaToken`

* Update token to include initial creation timestamp for database use
Annotate fields of HanaToken

* Update deprecation warning to use builtin decorator

* Update JWT base model to remove default values causing false positive validation results
Update unit tests to account for token changes

* Add jwt imports to api module for consistency with other submodules
  • Loading branch information
NeonDaniel authored Nov 19, 2024
1 parent 42a3f80 commit c3bc013
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 22 deletions.
8 changes: 5 additions & 3 deletions neon_data_models/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ class AccessRoles(IntEnum):
NONE = 0
GUEST = 1
USER = 2
ADMIN = 3
OWNER = 4

# 3-5 Reserved for "premium users"
ADMIN = 6
# 7-8 Reserved for "restricted owners"
OWNER = 9
# 10 Reserved for "unlimited access"
NODE = -1


Expand Down
1 change: 1 addition & 0 deletions neon_data_models/models/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@

from neon_data_models.models.api.node_v1 import *
from neon_data_models.models.api.mq import *
from neon_data_models.models.api.jwt import *
86 changes: 86 additions & 0 deletions neon_data_models/models/api/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2024 Neongecko.com Inc.
# BSD-3
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from time import time
from typing import Optional, List, Literal
from uuid import uuid4

from pydantic import Field
from neon_data_models.enum import AccessRoles
from neon_data_models.models.base import BaseModel


class JWT(BaseModel):
iss: Optional[str] = Field(None, validate_default=True,
description="Token issuer")
sub: Optional[str] = Field(None, validate_default=True,
description="Unique token subject, ie a user ID")
exp: int = Field(description="Expiration time in epoch seconds")
iat: int = Field(description="Token creation time in epoch seconds")
jti: str = Field(description="Unique token identifier",
default_factory=lambda: str(uuid4()))

client_id: str = Field(description="Client identifier")
roles: List[str] = Field(description="List of roles, "
"formatted as `<name> <AccessRole>`. "
"See PermissionsConfig for role names")


class HanaToken(JWT):
def __init__(self, **kwargs):
from neon_data_models.models.user import PermissionsConfig
permissions = kwargs.get("permissions")
if permissions and isinstance(permissions, PermissionsConfig):
kwargs["roles"] = permissions.to_roles()
elif permissions and isinstance(permissions, dict):
core_permissions = AccessRoles.GUEST if \
permissions.get("assist") else AccessRoles.NONE
diana_permissions = AccessRoles.GUEST if \
permissions.get("backend") else AccessRoles.NONE
node_permissions = AccessRoles.USER if \
permissions.get("node") else AccessRoles.NONE
kwargs["roles"] = [f"core {core_permissions.value}",
f"diana {diana_permissions.value}",
f"node {node_permissions.value}"]
if kwargs.get("expire") and isinstance(kwargs["expire"], float):
kwargs["expire"] = round(kwargs["expire"])
BaseModel.__init__(self, **kwargs)

# Private parameters
token_name: str = Field(default="",
description="Friendly name to identify this token.")
creation_timestamp: int = Field(default_factory=lambda: int(time()),
description="Timestamp of initial token "
"creation (not counting "
"refreshes).")
last_refresh_timestamp: Optional[int] = Field(default=None,
description="Timestamp of "
"most recent "
"refresh.")
purpose: Literal["access", "refresh"] = "access"


__all__ = [JWT.__name__, HanaToken.__name__]
7 changes: 3 additions & 4 deletions neon_data_models/models/api/node_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class TextInputData(BaseModel):

class NodeKlatResponse(BaseMessage):
msg_type: Literal["klat.response"] = "klat.response"
data: Dict[str, KlatResponse] = Field(type=Dict[str, KlatResponse],
description="dict of BCP-47 language: KlatResponse")
data: Dict[str, KlatResponse] = Field(
description="dict of BCP-47 language: KlatResponse")


class NodeAudioInputResponse(BaseMessage):
Expand All @@ -97,8 +97,7 @@ class NodeGetSttResponse(BaseMessage):
class NodeGetTtsResponse(BaseMessage):
msg_type: Literal["neon.get_tts.response"] = "neon.get_tts.response"
data: Dict[str, KlatResponse] = (
Field(type=Dict[str, KlatResponse],
description="dict of BCP-47 language: KlatResponse"))
Field(description="dict of BCP-47 language: KlatResponse"))


class CoreWWDetected(BaseMessage):
Expand Down
35 changes: 29 additions & 6 deletions neon_data_models/models/user/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

from time import time
from typing import Dict, Any, List, Literal, Optional
from typing_extensions import deprecated
from uuid import uuid4

from neon_data_models.models.api.jwt import HanaToken
from neon_data_models.models.base import BaseModel
from pydantic import Field
from datetime import date
Expand Down Expand Up @@ -124,7 +127,28 @@ class PermissionsConfig(BaseModel):
class Config:
use_enum_values = True


@classmethod
def from_roles(cls, roles: List[str]):
"""
Parse PermissionsConfig from standard JWT roles configuration.
"""
kwargs = {}
for role in roles:
name, value = role.split(' ')
kwargs[name] = AccessRoles[value.upper()]
return cls(**kwargs)

def to_roles(self):
"""
Dump a PermissionsConfig to standard JWT roles to be included in a JWT.
"""
roles = []
for key, val in self.model_dump().items():
roles.append(f"{key} {AccessRoles(val).name}")
return roles


@deprecated(f"Use `neon_data_models.models.api.jwt.HanaToken`")
class TokenConfig(BaseModel):
username: str
client_id: str
Expand All @@ -136,9 +160,9 @@ class TokenConfig(BaseModel):
description="Unix timestamp of refresh token expiration")
token_name: str
creation_timestamp: int = Field(
description="Unix timestamp of auth token creation")
description="Unix timestamp of token creation (auth+refresh)")
last_refresh_timestamp: int = Field(
description="Unix timestamp of last auth token refresh")
description="Unix timestamp of last token refresh (auth+refresh)")
access_token: Optional[str] = None


Expand All @@ -151,12 +175,11 @@ class User(BaseModel):
klat: KlatConfig = KlatConfig()
llm: BrainForgeConfig = BrainForgeConfig()
permissions: PermissionsConfig = PermissionsConfig()
tokens: Optional[List[TokenConfig]] = []
tokens: Optional[List[HanaToken]] = []

def __eq__(self, other):
return self.model_dump() == other.model_dump()


__all__ = [NeonUserConfig.__name__, KlatConfig.__name__,
BrainForgeConfig.__name__, PermissionsConfig.__name__,
TokenConfig.__name__, User.__name__]
BrainForgeConfig.__name__, PermissionsConfig.__name__, User.__name__]
81 changes: 72 additions & 9 deletions tests/models/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@

from time import time
from unittest import TestCase
from uuid import uuid4

from pydantic import ValidationError
from datetime import date
from neon_data_models.models.user.database import NeonUserConfig, TokenConfig, User

from neon_data_models.models.api.jwt import HanaToken
from neon_data_models.models.user.database import NeonUserConfig, User, PermissionsConfig


class TestDatabase(TestCase):
Expand Down Expand Up @@ -83,17 +87,17 @@ def test_neon_user_config(self):
def test_user(self):
user_kwargs = dict(username="test",
password_hash="test",
tokens=[{"username": "test",
"client_id": "test_id",
"permissions": {},
"refresh_token": "",
"expiration": round(time()),
"refresh_expiration": round(time()),
"token_name": "test_token",
tokens=[{"token_name": "test_token",
"jti": str(uuid4()),
"sub": str(uuid4()),
"client_id": str(uuid4()),
"roles": PermissionsConfig().to_roles(),
"iat": round(time()) - 1,
"exp": round(time()) + 1,
"creation_timestamp": round(time()),
"last_refresh_timestamp": round(time())}])
default_user = User(**user_kwargs)
self.assertIsInstance(default_user.tokens[0], TokenConfig)
self.assertIsInstance(default_user.tokens[0], HanaToken)
with self.assertRaises(ValidationError):
User()

Expand All @@ -105,6 +109,65 @@ def test_user(self):
self.assertNotEqual(default_user, duplicate_user)
self.assertEqual(default_user.tokens, duplicate_user.tokens)

def test_permissions_config(self):
from neon_data_models.models.user.database import PermissionsConfig
from neon_data_models.enum import AccessRoles

# Test Default
default_config = PermissionsConfig()
for _, value in default_config.model_dump().items():
self.assertEqual(value, AccessRoles.NONE)

test_config = PermissionsConfig(klat=AccessRoles.USER,
core=AccessRoles.GUEST,
diana=AccessRoles.GUEST,
node=AccessRoles.NODE,
hub=AccessRoles.NODE,
llm=AccessRoles.NONE)
# Test dump/load
self.assertEqual(PermissionsConfig(**test_config.model_dump()),
test_config)

# Test to/from roles
roles = test_config.to_roles()
self.assertIsInstance(roles, list)
for role in roles:
self.assertEqual(len(role.split()), 2)
self.assertEqual(PermissionsConfig.from_roles(roles), test_config)

def test_token_config(self):
from neon_data_models.models.user.database import PermissionsConfig
token_id = str(uuid4())
user_id = str(uuid4())
client_id = str(uuid4())
token_name = "Test Token"
permissions = PermissionsConfig()
refresh_expiration = round(time()) + 3600
creation = round(time()) - 3600
last_refresh = round(time())

from_database = HanaToken(token_name=token_name,
jti=token_id,
sub=user_id,
client_id=client_id,
roles=permissions.to_roles(),
exp=refresh_expiration,
iat=creation,
last_refresh_timestamp=last_refresh)

from_token = HanaToken(jti=token_id,
sub=user_id,
iat=creation,
exp=refresh_expiration,
token_name=token_name,
client_id=client_id,
permissions=permissions,
last_refresh_timestamp=last_refresh)

self.assertEqual(from_database, from_token)
self.assertEqual(from_database.model_dump_json(),
from_token.model_dump_json())


class TestNeonProfile(TestCase):
def test_create(self):
Expand Down

0 comments on commit c3bc013

Please sign in to comment.