Skip to content

Commit

Permalink
Merge pull request #86 from ral-facilities/fix-auth-after-image-uploa…
Browse files Browse the repository at this point in the history
…d-#85

Force auth before UploadFile upload #85
  • Loading branch information
joelvdavies authored Jan 14, 2025
2 parents 225598b + adc65dc commit d51376a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import logging

import jwt
from fastapi import HTTPException, Request, status
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from object_storage_api.core.config import config
from object_storage_api.core.consts import PUBLIC_KEY
Expand All @@ -17,37 +21,36 @@

logger = logging.getLogger()

security = HTTPBearer(auto_error=True)

class JWTBearer(HTTPBearer):

class JWTMiddleware(BaseHTTPMiddleware):
"""
Extends the FastAPI `HTTPBearer` class to provide JSON Web Token (JWT) based authentication/authorization.
A middleware class to provide JSON Web Token (JWT) based authentication/authorization.
"""

def __init__(self, auto_error: bool = True) -> None:
"""
Initialize the `JWTBearer`.
:param auto_error: If `True`, it automatically raises `HTTPException` if the HTTP Bearer token is not provided
(in an `Authorization` header).
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
"""
super().__init__(auto_error=auto_error)
Performs JWT access token authentication/authorization before processing the request.
async def __call__(self, request: Request) -> str:
"""
Callable method for JWT access token authentication/authorization.
This method is called when `JWTBearer` is used as a dependency in a FastAPI route. It performs authentication/
authorization by calling the parent class method and then verifying the JWT access token.
:param request: The FastAPI `Request` object.
:param request: The Starlette `Request` object.
:param call_next: The next function to call to process the `Request` object.
:return: The JWT access token if authentication is successful.
:raises HTTPException: If the supplied JWT access token is invalid or has expired.
"""
credentials: HTTPAuthorizationCredentials = await super().__call__(request)

if not self._is_jwt_access_token_valid(credentials.credentials):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token or expired token")

return credentials.credentials
if request.url.path not in [f"{config.api.root_path}/docs", f"{config.api.root_path}/openapi.json"]:
try:
credentials: HTTPAuthorizationCredentials = await security(request)
except HTTPException as exc:
# Cannot raise HttpException here, so must do manually
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})

if not self._is_jwt_access_token_valid(credentials.credentials):
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN, content={"detail": "Invalid token or expired token"}
)

return await call_next(request)

def _is_jwt_access_token_valid(self, access_token: str) -> bool:
"""
Expand Down
25 changes: 11 additions & 14 deletions object_storage_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,21 @@ async def custom_general_exception_handler(_: Request, exc: Exception) -> JSONRe
# pylint:disable=fixme
# TODO: The auth code in this file is identical to the one in inventory-management-system-api - Use common repo?

router_dependencies = []

def get_router_dependencies() -> list:
"""
Get the list of dependencies for the API routers.
:return: List of dependencies
"""
dependencies = []
# Include the `JWTBearer` as a dependency if authentication is enabled
if config.authentication.enabled is True:
# pylint:disable=import-outside-toplevel
from object_storage_api.auth.jwt_bearer import JWTBearer
# Add authentication middleware & dependency if enabled
if config.authentication.enabled is True:
# pylint:disable=import-outside-toplevel
from object_storage_api.auth.jwt_middleware import JWTMiddleware, security

dependencies.append(Depends(JWTBearer()))
return dependencies
app.add_middleware(JWTMiddleware)

# This router dependency is still needed for the swagger docs show the authorise button, even though the actual
# auth is done in the middleware
router_dependencies.append(Depends(security))


# Middlewares act in reverse order
app.add_middleware(
CORSMiddleware,
allow_origins=config.api.allowed_cors_origins,
Expand All @@ -96,8 +95,6 @@ def get_router_dependencies() -> list:
allow_headers=config.api.allowed_cors_headers,
)

router_dependencies = get_router_dependencies()

app.include_router(attachment.router, dependencies=router_dependencies)
app.include_router(image.router, dependencies=router_dependencies)

Expand Down
39 changes: 37 additions & 2 deletions test/e2e/test_jwt_bearer.py → test/e2e/test_jwt_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
from fastapi.routing import APIRoute


def test_jwt_middleware_allows_authenticated_request(test_client):
"""
Test the `JWTMiddleware` dependency appropriately allows a request with a valid bearer token.
"""

response = test_client.request("GET", "/attachments", headers={"Authorization": f"Bearer {VALID_ACCESS_TOKEN}"})
assert response.status_code == 200


@pytest.mark.parametrize(
"headers, expected_response_message",
[
Expand Down Expand Up @@ -51,9 +60,9 @@
),
],
)
def test_jwt_bearer_authorization_request(test_client, headers, expected_response_message):
def test_jwt_middleware_denies_unauthenticated_requests(test_client, headers, expected_response_message):
"""
Test the `JWTBearer` routers' dependency on all the API routes.
Test the `JWTMiddleware` dependency appropriately denies all requests without proper authentication.
"""
api_routes = [
api_route for api_route in test_client.app.routes if isinstance(api_route, APIRoute) and api_route.path != "/"
Expand All @@ -65,3 +74,29 @@ def test_jwt_bearer_authorization_request(test_client, headers, expected_respons
response = test_client.request(method, api_route.path, headers=headers)
assert response.status_code == 403
assert response.json()["detail"] == expected_response_message


@pytest.mark.parametrize(
"route_path",
[
pytest.param(
"/",
id="root",
),
pytest.param(
"/docs",
id="docs",
),
pytest.param(
"/openapi.json",
id="openapi.json",
),
],
)
def test_jwt_middleware_allows_unauthenticated_get_requests(test_client, route_path):
"""
Test the `JWTMiddleware` dependency appropriately allows GET requests at specific route paths.
"""

response = test_client.request("GET", route_path, headers={})
assert response.status_code == 200
Empty file removed test/unit/auth/__init__.py
Empty file.
133 changes: 0 additions & 133 deletions test/unit/auth/test_jwt_bearer.py

This file was deleted.

0 comments on commit d51376a

Please sign in to comment.