Skip to content

Commit

Permalink
feat: add coverage for ipv6 failure (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 9, 2023
1 parent 30b0fe7 commit 7aee8f6
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 5 deletions.
8 changes: 6 additions & 2 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from typing import List, Optional, Sequence, Tuple, Union

AddrInfoType = Tuple[
int, int, int, str, Union[Tuple[str, int], Tuple[str, int, int, int]]
Union[int, socket.AddressFamily],
Union[int, socket.SocketKind],
int,
str,
Tuple, # type: ignore[type-arg]
]


Expand Down Expand Up @@ -40,7 +44,7 @@ async def create_connection(
* ``sockaddr``: the socket address
This method is a coroutine which will try to establish the connection
in the background. When successful, the coroutine returns a
in the background. When successful, the coroutine returns a
socket.
"""
if not (current_loop := loop):
Expand Down
116 changes: 113 additions & 3 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import socket
from test.test_asyncio import utils as test_utils
from types import ModuleType
from typing import Tuple
from unittest import mock

import pytest
Expand Down Expand Up @@ -122,7 +123,14 @@ def _socket(*args, **kw):
socket.IPPROTO_TCP,
"",
("107.6.106.82", 80),
)
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
),
]
loop = asyncio.get_running_loop()
with mock.patch.object(loop, "sock_connect", return_value=None):
Expand Down Expand Up @@ -158,7 +166,14 @@ def _socket(*args, **kw):
socket.IPPROTO_TCP,
"",
("107.6.106.82", 80),
)
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
),
]
loop = asyncio.get_running_loop()
with mock.patch.object(loop, "sock_connect", return_value=None):
Expand Down Expand Up @@ -194,8 +209,103 @@ def _socket(*args, **kw):
socket.IPPROTO_TCP,
"",
("107.6.106.82", 80),
)
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
),
]
asyncio.get_running_loop()
with pytest.raises(OSError, match=errors[0]):
await create_connection(addr_info, happy_eyeballs_delay=0.3)


@pytest.mark.asyncio
@patch_socket
async def test_create_connection_ipv6_and_ipv4_happy_eyeballs_ipv6_fails(
m_socket: ModuleType,
) -> None:
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)

def _socket(*args, **kw):
if kw["family"] == socket.AF_INET6:
raise OSError("ipv6 fail")
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv4_addr_info]
loop = asyncio.get_running_loop()
with mock.patch.object(loop, "sock_connect", return_value=None):
assert (
await create_connection(addr_info, happy_eyeballs_delay=0.3) == mock_socket
)
assert mock_socket.family == socket.AF_INET


@pytest.mark.asyncio
@patch_socket
async def test_create_connection_ipv6_and_ipv4_happy_eyeballs_ipv4_fails(
m_socket: ModuleType,
) -> None:
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)

def _socket(*args, **kw):
if kw["family"] == socket.AF_INET:
raise OSError("ipv4 fail")
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

m_socket.socket = _socket # type: ignore
ipv6_addr: Tuple[str, int, int, int] = ("dead:beef::", 80, 0, 0)
ipv6_addr_info: Tuple[int, int, int, str, Tuple[str, int, int, int]] = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
ipv6_addr,
)
ipv4_addr: Tuple[str, int] = ("107.6.106.83", 80)
ipv4_addr_info: Tuple[int, int, int, str, Tuple[str, int]] = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
ipv4_addr,
)
addr_info = [ipv6_addr_info, ipv4_addr_info]
loop = asyncio.get_running_loop()
with mock.patch.object(loop, "sock_connect", return_value=None):
assert (
await create_connection(addr_info, happy_eyeballs_delay=0.3) == mock_socket
)
assert mock_socket.family == socket.AF_INET6

0 comments on commit 7aee8f6

Please sign in to comment.