Skip to content

Commit

Permalink
Bring us under 100 test_ssl type-check issues (#1404)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Jan 9, 2025
1 parent 317c7fa commit 9baefba
Showing 1 changed file with 61 additions and 35 deletions.
96 changes: 61 additions & 35 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,7 @@ def test_bio_write(self) -> None:
connection.bio_write(b"xy")
connection.bio_write(bytearray(b"za"))
with pytest.warns(DeprecationWarning):
connection.bio_write("deprecated")
connection.bio_write("deprecated") # type: ignore[arg-type]

def test_get_context(self) -> None:
"""
Expand All @@ -2357,11 +2357,11 @@ def test_set_context_wrong_args(self) -> None:
ctx = Context(SSLv23_METHOD)
connection = Connection(ctx, None)
with pytest.raises(TypeError):
connection.set_context(object())
connection.set_context(object()) # type: ignore[arg-type]
with pytest.raises(TypeError):
connection.set_context("hello")
connection.set_context("hello") # type: ignore[arg-type]
with pytest.raises(TypeError):
connection.set_context(1)
connection.set_context(1) # type: ignore[arg-type]
assert ctx is connection.get_context()

def test_set_context(self) -> None:
Expand All @@ -2387,12 +2387,12 @@ def test_set_tlsext_host_name_wrong_args(self) -> None:
"""
conn = Connection(Context(SSLv23_METHOD), None)
with pytest.raises(TypeError):
conn.set_tlsext_host_name(object())
conn.set_tlsext_host_name(object()) # type: ignore[arg-type]
with pytest.raises(TypeError):
conn.set_tlsext_host_name(b"with\0null")

with pytest.raises(TypeError):
conn.set_tlsext_host_name(b"example.com".decode("ascii"))
conn.set_tlsext_host_name(b"example.com".decode("ascii")) # type: ignore[arg-type]

def test_pending(self) -> None:
"""
Expand Down Expand Up @@ -2498,7 +2498,7 @@ def test_shutdown_wrong_args(self) -> None:
"""
connection = Connection(Context(SSLv23_METHOD), None)
with pytest.raises(TypeError):
connection.set_shutdown(None)
connection.set_shutdown(None) # type: ignore[arg-type]

def test_shutdown(self) -> None:
"""
Expand Down Expand Up @@ -2568,14 +2568,14 @@ def test_state_string(self) -> None:
the `Connection`.
"""
server, client = socket_pair()
server = loopback_server_factory(server)
client = loopback_client_factory(client)
tls_server = loopback_server_factory(server)
tls_client = loopback_client_factory(client)

assert server.get_state_string() in [
assert tls_server.get_state_string() in [
b"before/accept initialization",
b"before SSL initialization",
]
assert client.get_state_string() in [
assert tls_client.get_state_string() in [
b"before/connect initialization",
b"before SSL initialization",
]
Expand Down Expand Up @@ -2656,12 +2656,14 @@ def test_get_peer_cert_chain(self) -> None:
interact_in_memory(client, server)

chain = client.get_peer_cert_chain()
assert chain is not None
assert len(chain) == 3
assert "Server Certificate" == chain[0].get_subject().CN
assert "Intermediate Certificate" == chain[1].get_subject().CN
assert "Authority Certificate" == chain[2].get_subject().CN

cryptography_chain = client.get_peer_cert_chain(as_cryptography=True)
assert cryptography_chain is not None
assert len(cryptography_chain) == 3
assert (
cryptography_chain[0].subject.rfc4514_string()
Expand Down Expand Up @@ -2710,7 +2712,9 @@ def test_get_verified_chain(self) -> None:
clientContext = Context(SSLv23_METHOD)
# cacert is self-signed so the client must trust it for verification
# to succeed.
clientContext.get_cert_store().add_cert(cacert)
cert_store = clientContext.get_cert_store()
assert cert_store is not None
cert_store.add_cert(cacert)
clientContext.set_verify(VERIFY_PEER, verify_cb)
client = Connection(clientContext, None)
client.set_connect_state()
Expand Down Expand Up @@ -2774,10 +2778,10 @@ def test_set_verify_overrides_context(self) -> None:
assert conn.get_verify_mode() == VERIFY_NONE

with pytest.raises(TypeError):
conn.set_verify(None)
conn.set_verify(None) # type: ignore[arg-type]

with pytest.raises(TypeError):
conn.set_verify(VERIFY_PEER, "not a callable")
conn.set_verify(VERIFY_PEER, "not a callable") # type: ignore[arg-type]

def test_set_verify_callback_reference(self) -> None:
"""
Expand All @@ -2801,7 +2805,9 @@ def callback(conn, cert, errnum, depth, ok): # pragma: no cover
collect()
assert tracker()

conn.set_verify(VERIFY_PEER, lambda conn, cert, errnum, depth, ok: ok)
conn.set_verify(
VERIFY_PEER, lambda conn, cert, errnum, depth, ok: bool(ok)
)
collect()
collect()
callback = tracker()
Expand Down Expand Up @@ -2846,11 +2852,11 @@ def test_set_session_wrong_args(self) -> None:
ctx = Context(SSLv23_METHOD)
connection = Connection(ctx, None)
with pytest.raises(TypeError):
connection.set_session(123)
connection.set_session(123) # type: ignore[arg-type]
with pytest.raises(TypeError):
connection.set_session("hello")
connection.set_session("hello") # type: ignore[arg-type]
with pytest.raises(TypeError):
connection.set_session(object())
connection.set_session(object()) # type: ignore[arg-type]

def test_client_set_session(self) -> None:
"""
Expand All @@ -2872,6 +2878,7 @@ def makeServer(socket):

originalServer, originalClient = loopback(server_factory=makeServer)
originalSession = originalClient.get_session()
assert originalSession is not None

def makeClient(socket):
client = loopback_client_factory(socket)
Expand Down Expand Up @@ -2920,6 +2927,7 @@ def makeOriginalClient(socket):
server_factory=makeServer, client_factory=makeOriginalClient
)
originalSession = originalClient.get_session()
assert originalSession is not None

def makeClient(socket):
# Intentionally use a different, incompatible method here.
Expand Down Expand Up @@ -2993,8 +3001,9 @@ def test_get_finished(self) -> None:
"""
server, _ = loopback()

assert server.get_finished() is not None
assert len(server.get_finished()) > 0
finished = server.get_finished()
assert finished is not None
assert len(finished) > 0

def test_get_peer_finished(self) -> None:
"""
Expand All @@ -3004,8 +3013,9 @@ def test_get_peer_finished(self) -> None:
"""
server, _ = loopback()

assert server.get_peer_finished() is not None
assert len(server.get_peer_finished()) > 0
finished = server.get_peer_finished()
assert finished is not None
assert len(finished) > 0

def test_tls_finished_message_symmetry(self) -> None:
"""
Expand Down Expand Up @@ -3198,9 +3208,9 @@ def test_wrong_args(self) -> None:
"""
connection = Connection(Context(SSLv23_METHOD), None)
with pytest.raises(TypeError):
connection.send(object())
connection.send(object()) # type: ignore[arg-type]
with pytest.raises(TypeError):
connection.send([1, 2, 3])
connection.send([1, 2, 3]) # type: ignore[arg-type]

def test_short_bytes(self) -> None:
"""
Expand All @@ -3219,7 +3229,7 @@ def test_text(self) -> None:
"""
server, client = loopback()
with pytest.warns(DeprecationWarning) as w:
count = server.send(b"xy".decode("ascii"))
count = server.send(b"xy".decode("ascii")) # type: ignore[arg-type]
assert (
f"{WARNING_TYPE_EXPECTED} for buf is no longer accepted, "
f"use bytes"
Expand Down Expand Up @@ -3407,9 +3417,9 @@ def test_wrong_args(self) -> None:
"""
connection = Connection(Context(SSLv23_METHOD), None)
with pytest.raises(TypeError):
connection.sendall(object())
connection.sendall(object()) # type: ignore[arg-type]
with pytest.raises(TypeError):
connection.sendall([1, 2, 3])
connection.sendall([1, 2, 3]) # type: ignore[arg-type]

def test_short(self) -> None:
"""
Expand All @@ -3427,7 +3437,7 @@ def test_text(self) -> None:
"""
server, client = loopback()
with pytest.warns(DeprecationWarning) as w:
server.sendall(b"x".decode("ascii"))
server.sendall(b"x".decode("ascii")) # type: ignore[arg-type]
assert (
f"{WARNING_TYPE_EXPECTED} for buf is no longer accepted, "
f"use bytes"
Expand Down Expand Up @@ -3656,7 +3666,7 @@ class TestMemoryBIO:
Tests for `OpenSSL.SSL.Connection` using a memory BIO.
"""

def _server(self, sock):
def _server(self, sock: socket | None) -> Connection:
"""
Create a new server-side SSL `Connection` object wrapped around `sock`.
"""
Expand All @@ -3669,6 +3679,7 @@ def _server(self, sock):
verify_cb,
)
server_store = server_ctx.get_cert_store()
assert server_store is not None
server_ctx.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem)
)
Expand All @@ -3683,7 +3694,7 @@ def _server(self, sock):
server_conn.set_accept_state()
return server_conn

def _client(self, sock):
def _client(self, sock: socket | None) -> Connection:
"""
Create a new client-side SSL `Connection` object wrapped around `sock`.
"""
Expand All @@ -3696,6 +3707,7 @@ def _client(self, sock):
verify_cb,
)
client_store = client_ctx.get_cert_store()
assert client_store is not None
client_ctx.use_privatekey(
load_privatekey(FILETYPE_PEM, client_key_pem)
)
Expand Down Expand Up @@ -4154,14 +4166,22 @@ def inner(): # pragma: nocover
assert "Error text" in str(e.value)


T = typing.TypeVar("T")


class TestOCSP:
"""
Tests for PyOpenSSL's OCSP stapling support.
"""

sample_ocsp_data = b"this is totally ocsp data"

def _client_connection(self, callback, data, request_ocsp=True):
def _client_connection(
self,
callback: typing.Callable[[Connection, bytes, T | None], bool],
data: T | None,
request_ocsp=True,
) -> Connection:
"""
Builds a client connection suitable for using OCSP.
Expand All @@ -4181,7 +4201,11 @@ def _client_connection(self, callback, data, request_ocsp=True):
client.set_connect_state()
return client

def _server_connection(self, callback, data):
def _server_connection(
self,
callback: typing.Callable[[Connection, T | None], bytes],
data: T | None,
) -> Connection:
"""
Builds a server connection suitable for using OCSP.
Expand Down Expand Up @@ -4473,7 +4497,7 @@ class TestDTLS:
# Arbitrary number larger than any conceivable handshake volley.
LARGE_BUFFER = 65536

def _test_handshake_and_data(self, srtp_profile):
def _test_handshake_and_data(self, srtp_profile: bytes | None) -> None:
s_ctx = Context(DTLS_METHOD)

def generate_cookie(ssl):
Expand Down Expand Up @@ -4506,7 +4530,9 @@ def verify_cookie(ssl, cookie):

latest_client_hello = None

def pump_membio(label, source, sink):
def pump_membio(
label: str, source: Connection, sink: Connection
) -> bool:
try:
chunk = source.bio_read(self.LARGE_BUFFER)
except WantReadError:
Expand Down Expand Up @@ -4597,7 +4623,7 @@ def test_it_works_at_all(self) -> None:
def test_it_works_with_srtp(self) -> None:
self._test_handshake_and_data(srtp_profile=b"SRTP_AES128_CM_SHA1_80")

def test_timeout(self, monkeypatch) -> None:
def test_timeout(self, monkeypatch: pytest.MonkeyPatch) -> None:
c_ctx = Context(DTLS_METHOD)
c = Connection(c_ctx)

Expand Down

0 comments on commit 9baefba

Please sign in to comment.