Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: segregate connect/auth/refresh/enable device duties #233

Merged
merged 8 commits into from
Jul 24, 2024
101 changes: 51 additions & 50 deletions midealocal/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self._updates: list[Callable[[dict[str, Any]], None]] = []
self._unsupported_protocol: list[str] = []
self._is_run = False
self._available = True
self._available = False
self._appliance_query = True
self._refresh_interval = 30
self._heartbeat_interval = 10
Expand Down Expand Up @@ -190,67 +190,68 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]:
break
return result, msg

def connect(
self,
refresh_status: bool = True,
get_capabilities: bool = True,
) -> bool:
def _authenticate_refresh_enable(self) -> bool:
connected = self.connect()
if self._protocol == ProtocolVersion.V3:
self.authenticate()
self.refresh_status(wait_response=True)
self.get_capabilities()
return connected
chemelli74 marked this conversation as resolved.
Show resolved Hide resolved

def connect(self) -> bool:
"""Connect to device."""
connected = False
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
if self._protocol == ProtocolVersion.V3:
self.authenticate()
_LOGGER.debug("[%s] Authentication success", self._device_id)
if refresh_status:
self.refresh_status(wait_response=True)
if get_capabilities:
self.get_capabilities()
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
for _ in range(3):
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
connected = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add break after connected?

except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
self.enable_device(connected)
return connected

def authenticate(self) -> None:
"""Authenticate to device. V3 only."""
request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST)
_LOGGER.debug("[%s] Handshaking", self._device_id)
_LOGGER.debug("[%s] Authentication handshaking", self._device_id)
if not self._socket:
self.enable_device(False)
raise SocketException
self._socket.send(request)
response = self._socket.recv(512)
if len(response) < MIN_AUTH_RESPONSE:
self.enable_device(False)
raise AuthException
response = response[8:72]
self._security.tcp_key(response, self._key)
_LOGGER.debug("[%s] Authentication success", self._device_id)

def send_message(self, data: bytes) -> None:
"""Send message."""
Expand Down Expand Up @@ -462,6 +463,7 @@ def update_all(self, status: dict[str, Any]) -> None:

def enable_device(self, available: bool = True) -> None:
"""Enable device."""
_LOGGER.debug("[%s] Enabling device", self._device_id)
self._available = available
status = {"available": available}
self.update_all(status)
Expand Down Expand Up @@ -510,10 +512,9 @@ def _check_heartbeat(self, now: float) -> None:
def run(self) -> None:
"""Run loop."""
while self._is_run:
while self._socket is None:
if self.connect(refresh_status=True) is False:
self.close_socket()
time.sleep(5)
if not self._socket or not self.connect():
rokam marked this conversation as resolved.
Show resolved Hide resolved
raise SocketException
chemelli74 marked this conversation as resolved.
Show resolved Hide resolved
self._authenticate_refresh_enable()
timeout_counter = 0
start = time.time()
self._previous_refresh = start
Expand Down
60 changes: 16 additions & 44 deletions tests/device_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Midea Local device test."""

from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -28,7 +27,7 @@ def test_fetch_v2_message() -> None:
)


class MideaDeviceTest(IsolatedAsyncioTestCase):
class MideaDeviceTest:
"""Midea device test case."""

device: MideaDevice
Expand Down Expand Up @@ -59,55 +58,28 @@ def test_initial_attributes(self) -> None:
assert self.device.model == "test_model"
assert self.device.subtype == 1

def test_connect(self) -> None:
@pytest.mark.parametrize(
("exc", "result"),
[
(TimeoutError, False),
(OSError, False),
(AuthException, False),
(RefreshFailed, False),
(None, True),
],
)
def test_connect(self, exc: Exception, result: bool) -> None:
"""Test connect."""
with (
patch("socket.socket.connect") as connect_mock,
patch.object(
self.device,
"authenticate",
side_effect=[AuthException(), None, None],
),
patch.object(
self.device,
"refresh_status",
side_effect=[RefreshFailed(), None],
),
patch.object(
self.device,
"get_capabilities",
side_effect=[None],
),
):
connect_mock.side_effect = [
TimeoutError(),
OSError(),
None,
None,
None,
None,
]
assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is True
assert self.device.available is True
with patch("socket.socket.connect", side_effect=exc):
assert self.device.connect() is result
assert self.device.available is result

def test_connect_generic_exception(self) -> None:
"""Test connect with generic exception."""
with patch("socket.socket.connect") as connect_mock:
connect_mock.side_effect = Exception()

assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

def test_authenticate(self) -> None:
Expand Down