diff --git a/diana_services_api/app/__main__.py b/diana_services_api/app/__main__.py index b975618..59bf4a6 100644 --- a/diana_services_api/app/__main__.py +++ b/diana_services_api/app/__main__.py @@ -32,7 +32,7 @@ def main(): - config = Configuration().get("diana_services_api") + config = Configuration().get("diana_services_api", {}) app = create_app(config) uvicorn.run(app, host=config.get('server_host', "0.0.0.0"), port=config.get('port', 8080)) diff --git a/diana_services_api/app/routers/auth.py b/diana_services_api/app/routers/auth.py index 6d1b8f1..374b9e8 100644 --- a/diana_services_api/app/routers/auth.py +++ b/diana_services_api/app/routers/auth.py @@ -35,3 +35,8 @@ @auth_route.post("/login") async def check_login(request: AuthenticationRequest) -> AuthenticationResponse: return client_manager.check_auth_request(**dict(request)) + + +@auth_route.post("/refresh") +async def check_refresh(request: RefreshRequest) -> AuthenticationResponse: + return client_manager.check_refresh_request(**dict(request)) diff --git a/diana_services_api/auth/client_manager.py b/diana_services_api/auth/client_manager.py index 3133491..ac39ecb 100644 --- a/diana_services_api/auth/client_manager.py +++ b/diana_services_api/auth/client_manager.py @@ -44,6 +44,17 @@ def __init__(self, config: dict): self._disable_auth = config.get("disable_auth") self._jwt_algo = "HS256" + def _create_tokens(self, encode_data: dict) -> dict: + token = jwt.encode(encode_data, self._access_secret, self._jwt_algo) + encode_data['expire'] = time() + self._refresh_token_lifetime + encode_data['access_token'] = token + refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo) + # TODO: Store refresh token on server to allow invalidating clients + return {"username": encode_data['username'], + "client_id": encode_data['client_id'], + "access_token": token, + "refresh_token": refresh} + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None): if client_id in self.authorized_clients: @@ -56,17 +67,41 @@ def check_auth_request(self, client_id: str, username: str, "username": username, "password": password, "expire": expiration} - token = jwt.encode(encode_data, self._access_secret, self._jwt_algo) - encode_data['expire'] = time() + self._refresh_token_lifetime - refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo) - # TODO: Store refresh token on server to validate refresh requests - auth = {"username": username, - "client_id": client_id, - "access_token": token, - "refresh_token": refresh} + auth = self._create_tokens(encode_data) self.authorized_clients[client_id] = auth return auth + def check_refresh_request(self, access_token: str, refresh_token: str, + client_id: str): + # Read and validate refresh token + try: + refresh_data = jwt.decode(refresh_token, self._refresh_secret, + self._jwt_algo) + except DecodeError: + raise HTTPException(status_code=400, + detail="Invalid refresh token supplied") + if refresh_data['access_token'] != access_token: + raise HTTPException(status_code=403, + detail="Refresh and access token mismatch") + if time() > refresh_data['expire']: + raise HTTPException(status_code=401, + detail="Refresh token is expired") + # Read access token and re-generate a new pair of tokens + try: + token_data = jwt.decode(access_token, self._access_secret, + self._jwt_algo) + except DecodeError: + raise HTTPException(status_code=400, + detail="Invalid access token supplied") + if token_data['client_id'] != client_id: + raise HTTPException(status_code=403, + detail="Access token does not match client_id") + encode_data = {k: token_data[k] for k in + ("client_id", "username", "password")} + encode_data["expire"] = time() + self._access_token_lifetime + new_auth = self._create_tokens(encode_data) + return new_auth + def validate_auth(self, token: str) -> bool: if self._disable_auth: return True diff --git a/diana_services_api/schema/auth_requests.py b/diana_services_api/schema/auth_requests.py index 3c58df0..eef7dfb 100644 --- a/diana_services_api/schema/auth_requests.py +++ b/diana_services_api/schema/auth_requests.py @@ -57,3 +57,9 @@ class AuthenticationResponse(BaseModel): "access_token": "", "refresh_token": "" }]}} + + +class RefreshRequest(BaseModel): + access_token: str + refresh_token: str + client_id: str