Skip to content

Commit

Permalink
Bring full tests directory to typing correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Jan 9, 2025
1 parent 9422c36 commit 2945aae
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 56 deletions.
135 changes: 80 additions & 55 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import datetime
import gc
import os
import pathlib
import select
import sys
Expand All @@ -25,7 +26,6 @@
)
from gc import collect, get_referrers
from os import makedirs
from os.path import join
from socket import (
AF_INET,
AF_INET6,
Expand Down Expand Up @@ -124,6 +124,7 @@
WantWriteError,
ZeroReturnError,
_make_requires,
_NoOverlappingProtocols,
)

from .test_crypto import (
Expand Down Expand Up @@ -166,25 +167,10 @@ def loopback_address(socket: socket) -> str:
return "::1"


def join_bytes_or_unicode(prefix, suffix):
"""
Join two path components of either ``bytes`` or ``unicode``.
The return type is the same as the type of ``prefix``.
"""
# If the types are the same, nothing special is necessary.
if type(prefix) is type(suffix):
return join(prefix, suffix)

# Otherwise, coerce suffix to the type of prefix.
if isinstance(prefix, str):
return join(prefix, suffix.decode(getfilesystemencoding()))
else:
return join(prefix, suffix.encode(getfilesystemencoding()))


def verify_cb(conn, cert, errnum, depth, ok):
return ok
def verify_cb(
conn: Connection, cert: X509, errnum: int, depth: int, ok: int
) -> bool:
return bool(ok)


def socket_pair() -> tuple[socket, socket]:
Expand Down Expand Up @@ -360,7 +346,7 @@ def loopback(

def interact_in_memory(
client_conn: Connection, server_conn: Connection
) -> tuple[Connection, bytes]:
) -> tuple[Connection, bytes] | None:
"""
Try to read application bytes from each of the two `Connection` objects.
Copy bytes back and forth between their send/receive buffers for as long
Expand Down Expand Up @@ -405,6 +391,8 @@ def interact_in_memory(
wrote = True
write.bio_write(dirty)

return None


def handshake_in_memory(
client_conn: Connection, server_conn: Connection
Expand Down Expand Up @@ -1168,7 +1156,9 @@ def test_load_verify_invalid_file(self, tmpfile: bytes) -> None:
with pytest.raises(Error):
clientContext.load_verify_locations(tmpfile)

def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
def _load_verify_directory_locations_capath(
self, capath: str | bytes
) -> None:
"""
Verify that if path to a directory containing certificate files is
passed to ``Context.load_verify_locations`` for the ``capath``
Expand All @@ -1180,7 +1170,11 @@ def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
# c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
# from OpenSSL 1.0.0.
for name in [b"c7adac82.0", b"c3705638.0"]:
cafile = join_bytes_or_unicode(capath, name)
cafile: str | bytes
if isinstance(capath, str):
cafile = os.path.join(capath, name.decode())
else:
cafile = os.path.join(capath, name)
with open(cafile, "w") as fObj:
fObj.write(root_cert_pem.decode("ascii"))

Expand Down Expand Up @@ -1209,9 +1203,13 @@ def test_load_verify_directory_capath(
"""
if pathtype == "unicode_path":
tmpfile += NON_ASCII.encode(getfilesystemencoding())

if argtype == "unicode_arg":
tmpfile = tmpfile.decode(getfilesystemencoding())
self._load_verify_directory_locations_capath(tmpfile)
self._load_verify_directory_locations_capath(
tmpfile.decode(getfilesystemencoding())
)
else:
self._load_verify_directory_locations_capath(tmpfile)

def test_load_verify_locations_wrong_args(self) -> None:
"""
Expand Down Expand Up @@ -1393,7 +1391,14 @@ def test_set_verify_callback_connection_argument(self) -> None:
serverConnection = Connection(serverContext, None)

class VerifyCallback:
def callback(self, connection: Connection, *args) -> bool:
def callback(
self,
connection: Connection,
cert: X509,
err: int,
depth: int,
ok: int,
) -> bool:
self.connection = connection
return True

Expand Down Expand Up @@ -1452,7 +1457,9 @@ def test_set_verify_callback_exception(self) -> None:

clientContext = Context(TLSv1_2_METHOD)

def verify_callback(*args):
def verify_callback(
conn: Connection, cert: X509, err: int, depth: int, ok: int
) -> bool:
raise Exception("silly verify failure")

clientContext.set_verify(VERIFY_PEER, verify_callback)
Expand Down Expand Up @@ -1482,7 +1489,7 @@ def test_set_verify_callback_reference(self) -> None:

for i in range(5):

def verify_callback(*args):
def verify_callback(*args: object) -> bool:
return True

serverSocket, clientSocket = socket_pair()
Expand Down Expand Up @@ -1589,8 +1596,14 @@ def _use_certificate_chain_file_test(self, certdir: str | bytes) -> None:

makedirs(certdir)

chainFile = join_bytes_or_unicode(certdir, "chain.pem")
caFile = join_bytes_or_unicode(certdir, "ca.pem")
chainFile: str | bytes
caFile: str | bytes
if isinstance(certdir, str):
chainFile = os.path.join(certdir, "chain.pem")
caFile = os.path.join(certdir, "ca.pem")
else:
chainFile = os.path.join(certdir, b"chain.pem")
caFile = os.path.join(certdir, b"ca.pem")

# Write out the chain file.
with open(chainFile, "wb") as fObj:
Expand Down Expand Up @@ -1848,9 +1861,9 @@ def replacement(connection: Connection) -> None: # pragma: no cover
collect()
collect()

callback = tracker()
if callback is not None:
referrers = get_referrers(callback)
callback_ref = tracker()
if callback_ref is not None:
referrers = get_referrers(callback_ref)
assert len(referrers) == 1

def test_no_servername(self) -> None:
Expand Down Expand Up @@ -2064,7 +2077,9 @@ def test_alpn_no_server_overlap(self) -> None:
"""
refusal_args = []

def refusal(conn: Connection, options: list[bytes]):
def refusal(
conn: Connection, options: list[bytes]
) -> _NoOverlappingProtocols:
refusal_args.append((conn, options))
return NO_OVERLAPPING_PROTOCOLS

Expand Down Expand Up @@ -2218,7 +2233,7 @@ def test_construction(self) -> None:


@pytest.fixture(params=["context", "connection"])
def ctx_or_conn(request) -> Context | Connection:
def ctx_or_conn(request: pytest.FixtureRequest) -> Context | Connection:
ctx = Context(SSLv23_METHOD)
if request.param == "context":
return ctx
Expand Down Expand Up @@ -2823,9 +2838,9 @@ def callback(
)
collect()
collect()
callback = tracker()
if callback is not None: # pragma: nocover
referrers = get_referrers(callback)
callback_ref = tracker()
if callback_ref is not None: # pragma: nocover
referrers = get_referrers(callback_ref)
assert len(referrers) == 1

def test_get_session_unconnected(self) -> None:
Expand Down Expand Up @@ -3862,7 +3877,9 @@ def test_outgoing_overflow(self) -> None:
# meaningless.
assert sent < size

receiver, received = interact_in_memory(client, server)
result = interact_in_memory(client, server)
assert result is not None
receiver, received = result
assert receiver is server

# We can rely on all of these bytes being received at once because
Expand Down Expand Up @@ -4249,7 +4266,7 @@ def test_callbacks_arent_called_by_default(self) -> None:
called.
"""

def ocsp_callback(*args, **kwargs): # pragma: nocover
def ocsp_callback(*args: object) -> typing.NoReturn: # pragma: nocover
pytest.fail("Should not be called")

client = self._client_connection(
Expand Down Expand Up @@ -4284,7 +4301,7 @@ def test_client_receives_servers_data(self) -> None:
"""
calls = []

def server_callback(*args, **kwargs):
def server_callback(*args: object, **kwargs: object) -> bytes:
return self.sample_ocsp_data

def client_callback(
Expand All @@ -4307,11 +4324,15 @@ def test_callbacks_are_invoked_with_connections(self) -> None:
client_calls = []
server_calls = []

def client_callback(conn, *args, **kwargs):
def client_callback(
conn: Connection, *args: object, **kwargs: object
) -> bool:
client_calls.append(conn)
return True

def server_callback(conn, *args, **kwargs):
def server_callback(
conn: Connection, *args: object, **kwargs: object
) -> bytes:
server_calls.append(conn)
return self.sample_ocsp_data

Expand All @@ -4331,11 +4352,11 @@ def test_opaque_data_is_passed_through(self) -> None:
"""
calls = []

def server_callback(*args):
def server_callback(*args: object) -> bytes:
calls.append(args)
return self.sample_ocsp_data

def client_callback(*args):
def client_callback(*args: object) -> bool:
calls.append(args)
return True

Expand All @@ -4360,7 +4381,7 @@ def test_server_returns_empty_string(self) -> None:
"""
client_calls = []

def server_callback(*args):
def server_callback(*args: object) -> bytes:
return b""

def client_callback(
Expand All @@ -4381,10 +4402,10 @@ def test_client_returns_false_terminates_handshake(self) -> None:
If the client returns False from its callback, the handshake fails.
"""

def server_callback(*args):
def server_callback(*args: object) -> bytes:
return self.sample_ocsp_data

def client_callback(*args):
def client_callback(*args: object) -> bool:
return False

client = self._client_connection(callback=client_callback, data=None)
Expand All @@ -4401,10 +4422,10 @@ def test_exceptions_in_client_bubble_up(self) -> None:
class SentinelException(Exception):
pass

def server_callback(*args):
def server_callback(*args: object) -> bytes:
return self.sample_ocsp_data

def client_callback(*args):
def client_callback(*args: object) -> typing.NoReturn:
raise SentinelException()

client = self._client_connection(callback=client_callback, data=None)
Expand All @@ -4421,10 +4442,12 @@ def test_exceptions_in_server_bubble_up(self) -> None:
class SentinelException(Exception):
pass

def server_callback(*args):
def server_callback(*args: object) -> typing.NoReturn:
raise SentinelException()

def client_callback(*args): # pragma: nocover
def client_callback(
*args: object,
) -> typing.NoReturn: # pragma: nocover
pytest.fail("Should not be called")

client = self._client_connection(callback=client_callback, data=None)
Expand All @@ -4438,14 +4461,16 @@ def test_server_must_return_bytes(self) -> None:
The server callback must return a bytestring, or a TypeError is thrown.
"""

def server_callback(*args):
def server_callback(*args: object) -> str:
return self.sample_ocsp_data.decode("ascii")

def client_callback(*args): # pragma: nocover
def client_callback(
*args: object,
) -> typing.NoReturn: # pragma: nocover
pytest.fail("Should not be called")

client = self._client_connection(callback=client_callback, data=None)
server = self._server_connection(callback=server_callback, data=None)
server = self._server_connection(callback=server_callback, data=None) # type: ignore[arg-type]

with pytest.raises(TypeError):
handshake_in_memory(client, server)
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ extras =
deps =
mypy
commands =
mypy src/ tests/conftest.py tests/test_crypto.py tests/test_debug.py tests/test_rand.py tests/test_util.py tests/util.py
mypy src/ tests/

[testenv:check-manifest]
deps =
Expand Down

0 comments on commit 2945aae

Please sign in to comment.