From 9baefbaade98c6b76650fbfe07bf19b926578779 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 8 Jan 2025 21:57:00 -0500 Subject: [PATCH] Bring us under 100 test_ssl type-check issues (#1404) --- tests/test_ssl.py | 96 ++++++++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 35 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 03fce56b..dbd265b6 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -2338,7 +2338,7 @@ def test_bio_write(self) -> None: connection.bio_write(b"xy") connection.bio_write(bytearray(b"za")) with pytest.warns(DeprecationWarning): - connection.bio_write("deprecated") + connection.bio_write("deprecated") # type: ignore[arg-type] def test_get_context(self) -> None: """ @@ -2357,11 +2357,11 @@ def test_set_context_wrong_args(self) -> None: ctx = Context(SSLv23_METHOD) connection = Connection(ctx, None) with pytest.raises(TypeError): - connection.set_context(object()) + connection.set_context(object()) # type: ignore[arg-type] with pytest.raises(TypeError): - connection.set_context("hello") + connection.set_context("hello") # type: ignore[arg-type] with pytest.raises(TypeError): - connection.set_context(1) + connection.set_context(1) # type: ignore[arg-type] assert ctx is connection.get_context() def test_set_context(self) -> None: @@ -2387,12 +2387,12 @@ def test_set_tlsext_host_name_wrong_args(self) -> None: """ conn = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): - conn.set_tlsext_host_name(object()) + conn.set_tlsext_host_name(object()) # type: ignore[arg-type] with pytest.raises(TypeError): conn.set_tlsext_host_name(b"with\0null") with pytest.raises(TypeError): - conn.set_tlsext_host_name(b"example.com".decode("ascii")) + conn.set_tlsext_host_name(b"example.com".decode("ascii")) # type: ignore[arg-type] def test_pending(self) -> None: """ @@ -2498,7 +2498,7 @@ def test_shutdown_wrong_args(self) -> None: """ connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): - connection.set_shutdown(None) + connection.set_shutdown(None) # type: ignore[arg-type] def test_shutdown(self) -> None: """ @@ -2568,14 +2568,14 @@ def test_state_string(self) -> None: the `Connection`. """ server, client = socket_pair() - server = loopback_server_factory(server) - client = loopback_client_factory(client) + tls_server = loopback_server_factory(server) + tls_client = loopback_client_factory(client) - assert server.get_state_string() in [ + assert tls_server.get_state_string() in [ b"before/accept initialization", b"before SSL initialization", ] - assert client.get_state_string() in [ + assert tls_client.get_state_string() in [ b"before/connect initialization", b"before SSL initialization", ] @@ -2656,12 +2656,14 @@ def test_get_peer_cert_chain(self) -> None: interact_in_memory(client, server) chain = client.get_peer_cert_chain() + assert chain is not None assert len(chain) == 3 assert "Server Certificate" == chain[0].get_subject().CN assert "Intermediate Certificate" == chain[1].get_subject().CN assert "Authority Certificate" == chain[2].get_subject().CN cryptography_chain = client.get_peer_cert_chain(as_cryptography=True) + assert cryptography_chain is not None assert len(cryptography_chain) == 3 assert ( cryptography_chain[0].subject.rfc4514_string() @@ -2710,7 +2712,9 @@ def test_get_verified_chain(self) -> None: clientContext = Context(SSLv23_METHOD) # cacert is self-signed so the client must trust it for verification # to succeed. - clientContext.get_cert_store().add_cert(cacert) + cert_store = clientContext.get_cert_store() + assert cert_store is not None + cert_store.add_cert(cacert) clientContext.set_verify(VERIFY_PEER, verify_cb) client = Connection(clientContext, None) client.set_connect_state() @@ -2774,10 +2778,10 @@ def test_set_verify_overrides_context(self) -> None: assert conn.get_verify_mode() == VERIFY_NONE with pytest.raises(TypeError): - conn.set_verify(None) + conn.set_verify(None) # type: ignore[arg-type] with pytest.raises(TypeError): - conn.set_verify(VERIFY_PEER, "not a callable") + conn.set_verify(VERIFY_PEER, "not a callable") # type: ignore[arg-type] def test_set_verify_callback_reference(self) -> None: """ @@ -2801,7 +2805,9 @@ def callback(conn, cert, errnum, depth, ok): # pragma: no cover collect() assert tracker() - conn.set_verify(VERIFY_PEER, lambda conn, cert, errnum, depth, ok: ok) + conn.set_verify( + VERIFY_PEER, lambda conn, cert, errnum, depth, ok: bool(ok) + ) collect() collect() callback = tracker() @@ -2846,11 +2852,11 @@ def test_set_session_wrong_args(self) -> None: ctx = Context(SSLv23_METHOD) connection = Connection(ctx, None) with pytest.raises(TypeError): - connection.set_session(123) + connection.set_session(123) # type: ignore[arg-type] with pytest.raises(TypeError): - connection.set_session("hello") + connection.set_session("hello") # type: ignore[arg-type] with pytest.raises(TypeError): - connection.set_session(object()) + connection.set_session(object()) # type: ignore[arg-type] def test_client_set_session(self) -> None: """ @@ -2872,6 +2878,7 @@ def makeServer(socket): originalServer, originalClient = loopback(server_factory=makeServer) originalSession = originalClient.get_session() + assert originalSession is not None def makeClient(socket): client = loopback_client_factory(socket) @@ -2920,6 +2927,7 @@ def makeOriginalClient(socket): server_factory=makeServer, client_factory=makeOriginalClient ) originalSession = originalClient.get_session() + assert originalSession is not None def makeClient(socket): # Intentionally use a different, incompatible method here. @@ -2993,8 +3001,9 @@ def test_get_finished(self) -> None: """ server, _ = loopback() - assert server.get_finished() is not None - assert len(server.get_finished()) > 0 + finished = server.get_finished() + assert finished is not None + assert len(finished) > 0 def test_get_peer_finished(self) -> None: """ @@ -3004,8 +3013,9 @@ def test_get_peer_finished(self) -> None: """ server, _ = loopback() - assert server.get_peer_finished() is not None - assert len(server.get_peer_finished()) > 0 + finished = server.get_peer_finished() + assert finished is not None + assert len(finished) > 0 def test_tls_finished_message_symmetry(self) -> None: """ @@ -3198,9 +3208,9 @@ def test_wrong_args(self) -> None: """ connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): - connection.send(object()) + connection.send(object()) # type: ignore[arg-type] with pytest.raises(TypeError): - connection.send([1, 2, 3]) + connection.send([1, 2, 3]) # type: ignore[arg-type] def test_short_bytes(self) -> None: """ @@ -3219,7 +3229,7 @@ def test_text(self) -> None: """ server, client = loopback() with pytest.warns(DeprecationWarning) as w: - count = server.send(b"xy".decode("ascii")) + count = server.send(b"xy".decode("ascii")) # type: ignore[arg-type] assert ( f"{WARNING_TYPE_EXPECTED} for buf is no longer accepted, " f"use bytes" @@ -3407,9 +3417,9 @@ def test_wrong_args(self) -> None: """ connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): - connection.sendall(object()) + connection.sendall(object()) # type: ignore[arg-type] with pytest.raises(TypeError): - connection.sendall([1, 2, 3]) + connection.sendall([1, 2, 3]) # type: ignore[arg-type] def test_short(self) -> None: """ @@ -3427,7 +3437,7 @@ def test_text(self) -> None: """ server, client = loopback() with pytest.warns(DeprecationWarning) as w: - server.sendall(b"x".decode("ascii")) + server.sendall(b"x".decode("ascii")) # type: ignore[arg-type] assert ( f"{WARNING_TYPE_EXPECTED} for buf is no longer accepted, " f"use bytes" @@ -3656,7 +3666,7 @@ class TestMemoryBIO: Tests for `OpenSSL.SSL.Connection` using a memory BIO. """ - def _server(self, sock): + def _server(self, sock: socket | None) -> Connection: """ Create a new server-side SSL `Connection` object wrapped around `sock`. """ @@ -3669,6 +3679,7 @@ def _server(self, sock): verify_cb, ) server_store = server_ctx.get_cert_store() + assert server_store is not None server_ctx.use_privatekey( load_privatekey(FILETYPE_PEM, server_key_pem) ) @@ -3683,7 +3694,7 @@ def _server(self, sock): server_conn.set_accept_state() return server_conn - def _client(self, sock): + def _client(self, sock: socket | None) -> Connection: """ Create a new client-side SSL `Connection` object wrapped around `sock`. """ @@ -3696,6 +3707,7 @@ def _client(self, sock): verify_cb, ) client_store = client_ctx.get_cert_store() + assert client_store is not None client_ctx.use_privatekey( load_privatekey(FILETYPE_PEM, client_key_pem) ) @@ -4154,6 +4166,9 @@ def inner(): # pragma: nocover assert "Error text" in str(e.value) +T = typing.TypeVar("T") + + class TestOCSP: """ Tests for PyOpenSSL's OCSP stapling support. @@ -4161,7 +4176,12 @@ class TestOCSP: sample_ocsp_data = b"this is totally ocsp data" - def _client_connection(self, callback, data, request_ocsp=True): + def _client_connection( + self, + callback: typing.Callable[[Connection, bytes, T | None], bool], + data: T | None, + request_ocsp=True, + ) -> Connection: """ Builds a client connection suitable for using OCSP. @@ -4181,7 +4201,11 @@ def _client_connection(self, callback, data, request_ocsp=True): client.set_connect_state() return client - def _server_connection(self, callback, data): + def _server_connection( + self, + callback: typing.Callable[[Connection, T | None], bytes], + data: T | None, + ) -> Connection: """ Builds a server connection suitable for using OCSP. @@ -4473,7 +4497,7 @@ class TestDTLS: # Arbitrary number larger than any conceivable handshake volley. LARGE_BUFFER = 65536 - def _test_handshake_and_data(self, srtp_profile): + def _test_handshake_and_data(self, srtp_profile: bytes | None) -> None: s_ctx = Context(DTLS_METHOD) def generate_cookie(ssl): @@ -4506,7 +4530,9 @@ def verify_cookie(ssl, cookie): latest_client_hello = None - def pump_membio(label, source, sink): + def pump_membio( + label: str, source: Connection, sink: Connection + ) -> bool: try: chunk = source.bio_read(self.LARGE_BUFFER) except WantReadError: @@ -4597,7 +4623,7 @@ def test_it_works_at_all(self) -> None: def test_it_works_with_srtp(self) -> None: self._test_handshake_and_data(srtp_profile=b"SRTP_AES128_CM_SHA1_80") - def test_timeout(self, monkeypatch) -> None: + def test_timeout(self, monkeypatch: pytest.MonkeyPatch) -> None: c_ctx = Context(DTLS_METHOD) c = Connection(c_ctx)