Skip to content

Commit

Permalink
Add test for inherited UNIX sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
InvalidInterrupt committed Sep 4, 2023
1 parent f635094 commit 5517728
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 7 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ console_scripts =
[options.extras_require]
tests =
django
httpunixsocketconnection
hypothesis
pytest
pytest-asyncio
Expand Down
27 changes: 20 additions & 7 deletions tests/http_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ class DaphneTestCase(unittest.TestCase):
to store/retrieve the request/response messages.
"""

_instance_endpoint_args = {}

@staticmethod
def _get_instance_raw_socket_connection(test_app, *, timeout):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(timeout)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.connect((test_app.host, test_app.port))
return s

@staticmethod
def _get_instance_http_connection(test_app, *, timeout):
return HTTPConnection(test_app.host, test_app.port, timeout=timeout)

### Plain HTTP helpers

def run_daphne_http(
Expand All @@ -36,13 +50,15 @@ def run_daphne_http(
and response messages.
"""
with DaphneTestingInstance(
xff=xff, request_buffer_size=request_buffer_size
xff=xff,
request_buffer_size=request_buffer_size,
**self._instance_endpoint_args,
) as test_app:
# Add the response messages
test_app.add_send_messages(responses)
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
conn = self._get_instance_http_connection(test_app, timeout=timeout)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
Expand Down Expand Up @@ -74,13 +90,10 @@ def run_daphne_raw(self, data, *, responses=None, timeout=1):
Returns what Daphne sends back.
"""
assert isinstance(data, bytes)
with DaphneTestingInstance() as test_app:
with DaphneTestingInstance(**self._instance_endpoint_args) as test_app:
if responses is not None:
test_app.add_send_messages(responses)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(timeout)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.connect((test_app.host, test_app.port))
s = self._get_instance_raw_socket_connection(test_app, timeout=timeout)
s.send(data)
try:
return s.recv(1000000)
Expand Down
50 changes: 50 additions & 0 deletions tests/test_unixsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import socket
import weakref
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import skipUnless

import test_http_response
from http_base import DaphneTestCase
from httpunixsocketconnection import HTTPUnixSocketConnection

__all__ = ["UnixSocketFDDaphneTestCase", "TestInheritedUnixSocket"]


class UnixSocketFDDaphneTestCase(DaphneTestCase):
@property
def _instance_endpoint_args(self):
tmp_dir = TemporaryDirectory()
weakref.finalize(self, tmp_dir.cleanup)
sock_path = str(Path(tmp_dir.name, "test.sock"))
listen_sock = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
listen_sock.bind(sock_path)
listen_sock.listen()
listen_sock_fileno = os.dup(listen_sock.fileno())
os.set_inheritable(listen_sock_fileno, True)
listen_sock.close()
return {"host": None, "file_descriptor": listen_sock_fileno}

@staticmethod
def _get_instance_socket_path(test_app):
with socket.socket(fileno=os.dup(test_app.file_descriptor)) as sock:
return sock.getsockname()

@classmethod
def _get_instance_raw_socket_connection(cls, test_app, *, timeout):
socket_name = cls._get_instance_socket_path(test_app)
s = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
s.settimeout(timeout)
s.connect(socket_name)
return s

@classmethod
def _get_instance_http_connection(cls, test_app, *, timeout):
socket_name = cls._get_instance_socket_path(test_app)
return HTTPUnixSocketConnection(unix_socket=socket_name, timeout=timeout)


@skipUnless(hasattr(socket, "AF_UNIX"), "AF_UNIX support not present.")
class TestInheritedUnixSocket(UnixSocketFDDaphneTestCase):
test_minimal_response = test_http_response.TestHTTPResponse.test_minimal_response

0 comments on commit 5517728

Please sign in to comment.