Skip to content

Commit

Permalink
Revert "feat: segregate connect/auth/refresh/enable device duties (#233
Browse files Browse the repository at this point in the history
…)"

This reverts commit 681bd79.
  • Loading branch information
rokam authored Jul 24, 2024
1 parent 6f0a109 commit ceb499e
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 73 deletions.
106 changes: 53 additions & 53 deletions midealocal/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import IntEnum, StrEnum
from typing import Any

from .exceptions import CannotConnect, SocketException
from .exceptions import SocketException
from .message import (
MessageApplianceResponse,
MessageQueryAppliance,
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(
self._updates: list[Callable[[dict[str, Any]], None]] = []
self._unsupported_protocol: list[str] = []
self._is_run = False
self._available = False
self._available = True
self._appliance_query = True
self._refresh_interval = 30
self._heartbeat_interval = 10
Expand Down Expand Up @@ -190,66 +190,67 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]:
break
return result, msg

def _authenticate_refresh_capabilities(self) -> None:
if self._protocol == ProtocolVersion.V3:
self.authenticate()
self.refresh_status(wait_response=True)
self.get_capabilities()

def connect(self) -> bool:
def connect(
self,
refresh_status: bool = True,
get_capabilities: bool = True,
) -> bool:
"""Connect to device."""
connected = False
for _ in range(3):
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
if self._protocol == ProtocolVersion.V3:
self.authenticate()
_LOGGER.debug("[%s] Authentication success", self._device_id)
if refresh_status:
self.refresh_status(wait_response=True)
if get_capabilities:
self.get_capabilities()
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
self.enable_device(connected)
return connected

def authenticate(self) -> None:
"""Authenticate to device. V3 only."""
request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST)
_LOGGER.debug("[%s] Authentication handshaking", self._device_id)
_LOGGER.debug("[%s] Handshaking", self._device_id)
if not self._socket:
self.enable_device(False)
raise SocketException
self._socket.send(request)
response = self._socket.recv(512)
if len(response) < MIN_AUTH_RESPONSE:
self.enable_device(False)
raise AuthException
response = response[8:72]
self._security.tcp_key(response, self._key)
_LOGGER.debug("[%s] Authentication success", self._device_id)

def send_message(self, data: bytes) -> None:
"""Send message."""
Expand Down Expand Up @@ -461,7 +462,6 @@ def update_all(self, status: dict[str, Any]) -> None:

def enable_device(self, available: bool = True) -> None:
"""Enable device."""
_LOGGER.debug("[%s] Enabling device", self._device_id)
self._available = available
status = {"available": available}
self.update_all(status)
Expand Down Expand Up @@ -510,14 +510,14 @@ def _check_heartbeat(self, now: float) -> None:
def run(self) -> None:
"""Run loop."""
while self._is_run:
if not self.connect():
raise CannotConnect
if not self._socket:
raise SocketException
self._authenticate_refresh_capabilities()
while self._socket is None:
if self.connect(refresh_status=True) is False:
self.close_socket()
time.sleep(5)
timeout_counter = 0
start = time.time()
self._previous_refresh = self._previous_heartbeat = start
self._previous_refresh = start
self._previous_heartbeat = start
self._socket.settimeout(1)
while True:
try:
Expand Down
4 changes: 0 additions & 4 deletions midealocal/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ class CannotAuthenticate(MideaLocalError):
"""Exception raised when credentials are incorrect."""


class CannotConnect(MideaLocalError):
"""Exception raised when connection fails."""


class DataUnexpectedLength(MideaLocalError):
"""Exception raised when data length is less or more than expected."""

Expand Down
60 changes: 44 additions & 16 deletions tests/device_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Midea Local device test."""

from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -27,7 +28,7 @@ def test_fetch_v2_message() -> None:
)


class MideaDeviceTest:
class MideaDeviceTest(IsolatedAsyncioTestCase):
"""Midea device test case."""

device: MideaDevice
Expand Down Expand Up @@ -58,28 +59,55 @@ def test_initial_attributes(self) -> None:
assert self.device.model == "test_model"
assert self.device.subtype == 1

@pytest.mark.parametrize(
("exc", "result"),
[
(TimeoutError, False),
(OSError, False),
(AuthException, False),
(RefreshFailed, False),
(None, True),
],
)
def test_connect(self, exc: Exception, result: bool) -> None:
def test_connect(self) -> None:
"""Test connect."""
with patch("socket.socket.connect", side_effect=exc):
assert self.device.connect() is result
assert self.device.available is result
with (
patch("socket.socket.connect") as connect_mock,
patch.object(
self.device,
"authenticate",
side_effect=[AuthException(), None, None],
),
patch.object(
self.device,
"refresh_status",
side_effect=[RefreshFailed(), None],
),
patch.object(
self.device,
"get_capabilities",
side_effect=[None],
),
):
connect_mock.side_effect = [
TimeoutError(),
OSError(),
None,
None,
None,
None,
]
assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is True
assert self.device.available is True

def test_connect_generic_exception(self) -> None:
"""Test connect with generic exception."""
with patch("socket.socket.connect") as connect_mock:
connect_mock.side_effect = Exception()

assert self.device.connect() is False
assert self.device.connect(True, True) is False
assert self.device.available is False

def test_authenticate(self) -> None:
Expand Down

0 comments on commit ceb499e

Please sign in to comment.