Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring full tests directory to typing correctly #1407

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 83 additions & 58 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import datetime
import gc
import os
import pathlib
import select
import sys
Expand All @@ -25,7 +26,6 @@
)
from gc import collect, get_referrers
from os import makedirs
from os.path import join
from socket import (
AF_INET,
AF_INET6,
Expand Down Expand Up @@ -124,6 +124,7 @@
WantWriteError,
ZeroReturnError,
_make_requires,
_NoOverlappingProtocols,
)

from .test_crypto import (
Expand Down Expand Up @@ -166,25 +167,10 @@ def loopback_address(socket: socket) -> str:
return "::1"


def join_bytes_or_unicode(prefix, suffix):
"""
Join two path components of either ``bytes`` or ``unicode``.

The return type is the same as the type of ``prefix``.
"""
# If the types are the same, nothing special is necessary.
if type(prefix) is type(suffix):
return join(prefix, suffix)

# Otherwise, coerce suffix to the type of prefix.
if isinstance(prefix, str):
return join(prefix, suffix.decode(getfilesystemencoding()))
else:
return join(prefix, suffix.encode(getfilesystemencoding()))


def verify_cb(conn, cert, errnum, depth, ok):
return ok
def verify_cb(
conn: Connection, cert: X509, errnum: int, depth: int, ok: int
) -> bool:
return bool(ok)


def socket_pair() -> tuple[socket, socket]:
Expand Down Expand Up @@ -360,7 +346,7 @@ def loopback(

def interact_in_memory(
client_conn: Connection, server_conn: Connection
) -> tuple[Connection, bytes]:
) -> tuple[Connection, bytes] | None:
"""
Try to read application bytes from each of the two `Connection` objects.
Copy bytes back and forth between their send/receive buffers for as long
Expand Down Expand Up @@ -405,6 +391,8 @@ def interact_in_memory(
wrote = True
write.bio_write(dirty)

return None


def handshake_in_memory(
client_conn: Connection, server_conn: Connection
Expand Down Expand Up @@ -1021,9 +1009,9 @@ def info(conn: Connection, where: int, ret: int) -> None:
for (conn, where, ret) in called
if not isinstance(conn, Connection)
]
assert (
[] == notConnections
), "Some info callback arguments were not Connection instances."
assert [] == notConnections, (
"Some info callback arguments were not Connection instances."
)

@pytest.mark.skipif(
not getattr(_lib, "Cryptography_HAS_KEYLOG", None),
Expand Down Expand Up @@ -1168,7 +1156,9 @@ def test_load_verify_invalid_file(self, tmpfile: bytes) -> None:
with pytest.raises(Error):
clientContext.load_verify_locations(tmpfile)

def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
def _load_verify_directory_locations_capath(
self, capath: str | bytes
) -> None:
"""
Verify that if path to a directory containing certificate files is
passed to ``Context.load_verify_locations`` for the ``capath``
Expand All @@ -1180,7 +1170,11 @@ def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
# c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
# from OpenSSL 1.0.0.
for name in [b"c7adac82.0", b"c3705638.0"]:
cafile = join_bytes_or_unicode(capath, name)
cafile: str | bytes
if isinstance(capath, str):
cafile = os.path.join(capath, name.decode())
else:
cafile = os.path.join(capath, name)
with open(cafile, "w") as fObj:
fObj.write(root_cert_pem.decode("ascii"))

Expand Down Expand Up @@ -1209,9 +1203,13 @@ def test_load_verify_directory_capath(
"""
if pathtype == "unicode_path":
tmpfile += NON_ASCII.encode(getfilesystemencoding())

if argtype == "unicode_arg":
tmpfile = tmpfile.decode(getfilesystemencoding())
self._load_verify_directory_locations_capath(tmpfile)
self._load_verify_directory_locations_capath(
tmpfile.decode(getfilesystemencoding())
)
else:
self._load_verify_directory_locations_capath(tmpfile)

def test_load_verify_locations_wrong_args(self) -> None:
"""
Expand Down Expand Up @@ -1393,7 +1391,14 @@ def test_set_verify_callback_connection_argument(self) -> None:
serverConnection = Connection(serverContext, None)

class VerifyCallback:
def callback(self, connection: Connection, *args) -> bool:
def callback(
self,
connection: Connection,
cert: X509,
err: int,
depth: int,
ok: int,
) -> bool:
self.connection = connection
return True

Expand Down Expand Up @@ -1452,7 +1457,9 @@ def test_set_verify_callback_exception(self) -> None:

clientContext = Context(TLSv1_2_METHOD)

def verify_callback(*args):
def verify_callback(
conn: Connection, cert: X509, err: int, depth: int, ok: int
) -> bool:
raise Exception("silly verify failure")

clientContext.set_verify(VERIFY_PEER, verify_callback)
Expand Down Expand Up @@ -1482,7 +1489,7 @@ def test_set_verify_callback_reference(self) -> None:

for i in range(5):

def verify_callback(*args):
def verify_callback(*args: object) -> bool:
return True

serverSocket, clientSocket = socket_pair()
Expand Down Expand Up @@ -1589,8 +1596,14 @@ def _use_certificate_chain_file_test(self, certdir: str | bytes) -> None:

makedirs(certdir)

chainFile = join_bytes_or_unicode(certdir, "chain.pem")
caFile = join_bytes_or_unicode(certdir, "ca.pem")
chainFile: str | bytes
caFile: str | bytes
if isinstance(certdir, str):
chainFile = os.path.join(certdir, "chain.pem")
caFile = os.path.join(certdir, "ca.pem")
else:
chainFile = os.path.join(certdir, b"chain.pem")
caFile = os.path.join(certdir, b"ca.pem")

# Write out the chain file.
with open(chainFile, "wb") as fObj:
Expand Down Expand Up @@ -1848,9 +1861,9 @@ def replacement(connection: Connection) -> None: # pragma: no cover
collect()
collect()

callback = tracker()
if callback is not None:
referrers = get_referrers(callback)
callback_ref = tracker()
if callback_ref is not None:
referrers = get_referrers(callback_ref)
assert len(referrers) == 1

def test_no_servername(self) -> None:
Expand Down Expand Up @@ -2064,7 +2077,9 @@ def test_alpn_no_server_overlap(self) -> None:
"""
refusal_args = []

def refusal(conn: Connection, options: list[bytes]):
def refusal(
conn: Connection, options: list[bytes]
) -> _NoOverlappingProtocols:
refusal_args.append((conn, options))
return NO_OVERLAPPING_PROTOCOLS

Expand Down Expand Up @@ -2218,7 +2233,7 @@ def test_construction(self) -> None:


@pytest.fixture(params=["context", "connection"])
def ctx_or_conn(request) -> Context | Connection:
def ctx_or_conn(request: pytest.FixtureRequest) -> Context | Connection:
ctx = Context(SSLv23_METHOD)
if request.param == "context":
return ctx
Expand Down Expand Up @@ -2823,9 +2838,9 @@ def callback(
)
collect()
collect()
callback = tracker()
if callback is not None: # pragma: nocover
referrers = get_referrers(callback)
callback_ref = tracker()
if callback_ref is not None: # pragma: nocover
referrers = get_referrers(callback_ref)
assert len(referrers) == 1

def test_get_session_unconnected(self) -> None:
Expand Down Expand Up @@ -3862,7 +3877,9 @@ def test_outgoing_overflow(self) -> None:
# meaningless.
assert sent < size

receiver, received = interact_in_memory(client, server)
result = interact_in_memory(client, server)
assert result is not None
receiver, received = result
assert receiver is server

# We can rely on all of these bytes being received at once because
Expand Down Expand Up @@ -4249,7 +4266,7 @@ def test_callbacks_arent_called_by_default(self) -> None:
called.
"""

def ocsp_callback(*args, **kwargs): # pragma: nocover
def ocsp_callback(*args: object) -> typing.NoReturn: # pragma: nocover
pytest.fail("Should not be called")

client = self._client_connection(
Expand Down Expand Up @@ -4284,7 +4301,7 @@ def test_client_receives_servers_data(self) -> None:
"""
calls = []

def server_callback(*args, **kwargs):
def server_callback(*args: object, **kwargs: object) -> bytes:
return self.sample_ocsp_data

def client_callback(
Expand All @@ -4307,11 +4324,15 @@ def test_callbacks_are_invoked_with_connections(self) -> None:
client_calls = []
server_calls = []

def client_callback(conn, *args, **kwargs):
def client_callback(
conn: Connection, *args: object, **kwargs: object
) -> bool:
client_calls.append(conn)
return True

def server_callback(conn, *args, **kwargs):
def server_callback(
conn: Connection, *args: object, **kwargs: object
) -> bytes:
server_calls.append(conn)
return self.sample_ocsp_data

Expand All @@ -4331,11 +4352,11 @@ def test_opaque_data_is_passed_through(self) -> None:
"""
calls = []

def server_callback(*args):
def server_callback(*args: object) -> bytes:
calls.append(args)
return self.sample_ocsp_data

def client_callback(*args):
def client_callback(*args: object) -> bool:
calls.append(args)
return True

Expand All @@ -4360,7 +4381,7 @@ def test_server_returns_empty_string(self) -> None:
"""
client_calls = []

def server_callback(*args):
def server_callback(*args: object) -> bytes:
return b""

def client_callback(
Expand All @@ -4381,10 +4402,10 @@ def test_client_returns_false_terminates_handshake(self) -> None:
If the client returns False from its callback, the handshake fails.
"""

def server_callback(*args):
def server_callback(*args: object) -> bytes:
return self.sample_ocsp_data

def client_callback(*args):
def client_callback(*args: object) -> bool:
return False

client = self._client_connection(callback=client_callback, data=None)
Expand All @@ -4401,10 +4422,10 @@ def test_exceptions_in_client_bubble_up(self) -> None:
class SentinelException(Exception):
pass

def server_callback(*args):
def server_callback(*args: object) -> bytes:
return self.sample_ocsp_data

def client_callback(*args):
def client_callback(*args: object) -> typing.NoReturn:
raise SentinelException()

client = self._client_connection(callback=client_callback, data=None)
Expand All @@ -4421,10 +4442,12 @@ def test_exceptions_in_server_bubble_up(self) -> None:
class SentinelException(Exception):
pass

def server_callback(*args):
def server_callback(*args: object) -> typing.NoReturn:
raise SentinelException()

def client_callback(*args): # pragma: nocover
def client_callback(
*args: object,
) -> typing.NoReturn: # pragma: nocover
pytest.fail("Should not be called")

client = self._client_connection(callback=client_callback, data=None)
Expand All @@ -4438,14 +4461,16 @@ def test_server_must_return_bytes(self) -> None:
The server callback must return a bytestring, or a TypeError is thrown.
"""

def server_callback(*args):
def server_callback(*args: object) -> str:
return self.sample_ocsp_data.decode("ascii")

def client_callback(*args): # pragma: nocover
def client_callback(
*args: object,
) -> typing.NoReturn: # pragma: nocover
pytest.fail("Should not be called")

client = self._client_connection(callback=client_callback, data=None)
server = self._server_connection(callback=server_callback, data=None)
server = self._server_connection(callback=server_callback, data=None) # type: ignore[arg-type]

with pytest.raises(TypeError):
handshake_in_memory(client, server)
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ extras =
deps =
mypy
commands =
mypy src/ tests/conftest.py tests/test_crypto.py tests/test_debug.py tests/test_rand.py tests/test_util.py tests/util.py
mypy src/ tests/

[testenv:check-manifest]
deps =
Expand Down
Loading