Skip to content

Commit

Permalink
fix: add support for asgi 3.0 while maintaining 2.0 support. (#1)
Browse files Browse the repository at this point in the history
* fix: add support for asgi 3.0 while maintaining 2.0 support.

* fix: OPTIONS does not make sense in Access-Control-Allow-Methods

* fix: non-CORS OPTIONS requests do not need CORS headers.

* fix: inclusion of safelisted headers has meaning and should not be done by default.

* style: 204 is fine for successful preflights.

* fix: 403 looks to be the explicit response mentioned in the living fetch spec.

* fix: origin header only varies the response when there are multiple origins.

* test: expand coverage and test that non http forwarding works.

* fix: use asgiref to do asgi2 compatibility.
  • Loading branch information
jhillacre authored Aug 1, 2023
1 parent 3d70de6 commit 84c7351
Show file tree
Hide file tree
Showing 6 changed files with 900 additions and 109 deletions.
3 changes: 2 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[run]
branch = True
source = asgi_middleware
source = asgi_cors_middleware

[report]
exclude_lines =
Expand All @@ -20,5 +20,6 @@ omit =
setup.py
*/distutils/*
tests/*
asgi_cors_middleware/__init__.py

show_missing = True
52 changes: 30 additions & 22 deletions asgi_cors_middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
import re
import typing

from asgiref.compatibility import guarantee_single_callable
from starlette.datastructures import Headers, MutableHeaders

from starlette.responses import PlainTextResponse
from starlette.responses import Response

ALL_METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT")
# OPTIONS doesn't make sense to return as an allowed method for CORS.
# See https://stackoverflow.com/a/68529748
ALL_METHODS = ("DELETE", "GET", "PATCH", "POST", "PUT")
SAFELISTED_HEADERS = {
"Accept", "Accept-Language", "Content-Language", "Content-Type"
}
Expand Down Expand Up @@ -49,22 +53,25 @@ def __init__(
preflight_headers = {}
if "*" in origins:
preflight_headers["Access-Control-Allow-Origin"] = "*"
else:
elif len(origins) > 1 or compiled_allow_origin_regex is not None:
preflight_headers["Vary"] = "Origin"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
# re-including normally safelisted headers implies that you want to lift the browsers
# additional restrictions on those headers. we don't want to do that by default.
# See https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_request_header#additional_restrictions
allow_headers = sorted(set(allow_headers))
if allow_headers and "*" not in allow_headers:
preflight_headers["Access-Control-Allow-Headers"] = \
", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"

self.app = app
self.app = guarantee_single_callable(app)
self.origins = origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
Expand All @@ -77,24 +84,24 @@ def __init__(
async def __call__(
self, scope, receive, send
) -> None:
if scope["type"] != "http": # pragma: no cover
handler = await self.app(scope, receive, send)
await handler.__call__(receive, send)
return
if scope["type"] != "http":
return await self.app(scope, receive, send)

method = scope["method"]
headers = Headers(scope=scope)
origin = headers.get("origin")

if origin is None:
handler = await self.app(scope, receive, send)
await handler.__call__(receive, send)
return
return await self.app(scope, receive, send)

if method == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
if method == "OPTIONS":
if "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
# if this is an options request but was not a cors preflight,
# we should skip the simple response processing.
return await self.app(scope, receive, send)

await self.simple_response(
scope, receive, send, request_headers=headers
Expand All @@ -110,7 +117,7 @@ def is_allowed_origin(self, origin: str) -> bool:

return any(host in origin for host in self.origins)

def preflight_response(self, request_headers) -> PlainTextResponse:
def preflight_response(self, request_headers) -> Response:
requested_origin = request_headers["origin"]
requested_method = request_headers["access-control-request-method"]
requested_headers = request_headers.get(
Expand All @@ -133,16 +140,17 @@ def preflight_response(self, request_headers) -> PlainTextResponse:
headers["Access-Control-Allow-Headers"] = requested_headers
elif requested_headers is not None:
for header in [h.lower() for h in requested_headers.split(",")]:
if header.strip() not in self.allow_headers:
requested_header = header.strip()
if requested_header not in self.allow_headers and requested_method not in SAFELISTED_HEADERS:
failures.append("headers")

if failures:
failure_text = "Disallowed CORS " + ", ".join(failures)
return PlainTextResponse(
failure_text, status_code=400, headers=headers
failure_text, status_code=403, headers=headers
)

return PlainTextResponse("OK", status_code=200, headers=headers)
return Response("", status_code=204, headers=headers)

async def simple_response(
self,
Expand All @@ -154,8 +162,7 @@ async def simple_response(
send = functools.partial(
self.send, send=send, request_headers=request_headers
)
handler = await self.app(scope, receive, send)
await handler(receive, send)
return await self.app(scope, receive, send)

async def send(self, message, send, request_headers) -> None:
if message["type"] != "http.response.start":
Expand All @@ -174,5 +181,6 @@ async def send(self, message, send, request_headers) -> None:
elif not self.allow_all_origins and \
self.is_allowed_origin(origin=origin):
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")
if len(self.origins) > 1 or self.allow_origin_regex is not None:
headers.add_vary_header("Origin")
await send(message)
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
asgiref==3.3.4
asgiref==3.7.2
coverage==5.5
pytest==6.2.4
pytest-cov==2.12.0
starlette==0.14.2
twine==3.4.1
wheel==0.36.2
pytest-asyncio==0.15.1
setuptools==56.2.0
setuptools==56.2.0
channels>=3.0.4,<4.0.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
],
packages=["asgi_cors_middleware"],
include_package_data=True,
install_requires=["starlette"]
install_requires=["starlette", "asgiref>=3.5.0"]
)
72 changes: 62 additions & 10 deletions tests/asgi_app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,64 @@
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route

async def app(scope, receive, send):
if scope["type"] == "websocket":
await send({"type": "websocket.accept"})
while True:
message = await receive()
if message["type"] == "websocket.disconnect":
break
elif message["type"] == "websocket.receive":
# echo
await send({
"type": "websocket.send",
"text": message["text"],
})
return
elif scope['type'] == 'http':
is_get = scope['method'] == 'GET'
is_options = scope['method'] == 'OPTIONS'
is_homepage = scope['path'] == '/'
if is_get and is_homepage:
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
(b'content-length', b'17'),
(b'content-type', b'application/json'),
],
})
await send({
'type': 'http.response.body',
'body': b'{"hello":"world"}',
})
return
elif is_options and is_homepage:
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
(b'allow', b'GET, OPTIONS'),
(b'content-length', b'2'),
],
})
await send({
'type': 'http.response.body',
'body': b'OK',
})
return
await send({
'type': 'http.response.start',
'status': 500,
'headers': [
(b'content-type', b'text/plain'),
],
})
await send({
'type': 'http.response.body',
'body': b'Internal Server Error',
})

async def homepage(request):
return JSONResponse({'hello': 'world'})


app = Starlette(debug=True, routes=[
Route('/', homepage),
])
class ASGI2app():
def __init__(self, scope):
self.scope = scope
async def __call__(self, receive, send):
return await app(self.scope, receive, send)
Loading

0 comments on commit 84c7351

Please sign in to comment.