From ae8f1ea98b6c7bb1ef8e3c7bb551a5c31c763b4f Mon Sep 17 00:00:00 2001 From: Simone Chemelli Date: Fri, 19 Jul 2024 14:28:30 +0000 Subject: [PATCH 1/6] feat: segregate connect/auth/refresh/enable device duties --- midealocal/device.py | 99 ++++++++++++++++++++++---------------------- tests/device_test.py | 12 +++--- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/midealocal/device.py b/midealocal/device.py index 10d92ded..982de80e 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -190,59 +190,57 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]: break return result, msg - def connect( - self, - refresh_status: bool = True, - get_capabilities: bool = True, - ) -> bool: + def _authenticate_refresh_enable(self) -> bool: + connected = self.connect() + if self._protocol == ProtocolVersion.V3: + self.authenticate() + self.refresh_status(wait_response=True) + self.get_capabilities() + self.enable_device(connected) + return connected + + def connect(self) -> bool: """Connect to device.""" connected = False - try: - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.settimeout(10) - _LOGGER.debug( - "[%s] Connecting to %s:%s", - self._device_id, - self._ip_address, - self._port, - ) - self._socket.connect((self._ip_address, self._port)) - _LOGGER.debug("[%s] Connected", self._device_id) - if self._protocol == ProtocolVersion.V3: - self.authenticate() - _LOGGER.debug("[%s] Authentication success", self._device_id) - if refresh_status: - self.refresh_status(wait_response=True) - if get_capabilities: - self.get_capabilities() - connected = True - except TimeoutError: - _LOGGER.debug("[%s] Connection timed out", self._device_id) - except OSError: - _LOGGER.debug("[%s] Connection error", self._device_id) - except AuthException: - _LOGGER.debug("[%s] Authentication failed", self._device_id) - except RefreshFailed: - _LOGGER.debug("[%s] Refresh status is timed out", self._device_id) - except Exception as e: - file = None - lineno = None - if e.__traceback__: - file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101 - lineno = e.__traceback__.tb_lineno - _LOGGER.exception( - "[%s] Unknown error : %s, %s", - self._device_id, - file, - lineno, - ) - self.enable_device(connected) + for _ in range(3): + try: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.settimeout(10) + _LOGGER.debug( + "[%s] Connecting to %s:%s", + self._device_id, + self._ip_address, + self._port, + ) + self._socket.connect((self._ip_address, self._port)) + _LOGGER.debug("[%s] Connected", self._device_id) + connected = True + except TimeoutError: + _LOGGER.debug("[%s] Connection timed out", self._device_id) + except OSError: + _LOGGER.debug("[%s] Connection error", self._device_id) + except AuthException: + _LOGGER.debug("[%s] Authentication failed", self._device_id) + except RefreshFailed: + _LOGGER.debug("[%s] Refresh status is timed out", self._device_id) + except Exception as e: + file = None + lineno = None + if e.__traceback__: + file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101 + lineno = e.__traceback__.tb_lineno + _LOGGER.exception( + "[%s] Unknown error : %s, %s", + self._device_id, + file, + lineno, + ) return connected def authenticate(self) -> None: """Authenticate to device. V3 only.""" request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST) - _LOGGER.debug("[%s] Handshaking", self._device_id) + _LOGGER.debug("[%s] Authentication handshaking", self._device_id) if not self._socket: raise SocketException self._socket.send(request) @@ -251,6 +249,7 @@ def authenticate(self) -> None: raise AuthException response = response[8:72] self._security.tcp_key(response, self._key) + _LOGGER.debug("[%s] Authentication success", self._device_id) def send_message(self, data: bytes) -> None: """Send message.""" @@ -462,6 +461,7 @@ def update_all(self, status: dict[str, Any]) -> None: def enable_device(self, available: bool = True) -> None: """Enable device.""" + _LOGGER.debug("[%s] Enabling device", self._device_id) self._available = available status = {"available": available} self.update_all(status) @@ -510,10 +510,9 @@ def _check_heartbeat(self, now: float) -> None: def run(self) -> None: """Run loop.""" while self._is_run: - while self._socket is None: - if self.connect(refresh_status=True) is False: - self.close_socket() - time.sleep(5) + if not self._socket or not self.connect(): + raise SocketException + self._authenticate_refresh_enable() timeout_counter = 0 start = time.time() self._previous_refresh = start diff --git a/tests/device_test.py b/tests/device_test.py index 491813be..21714ce6 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -87,19 +87,19 @@ def test_connect(self) -> None: None, None, ] - assert self.device.connect(True, True) is False + assert self.device.connect() is False assert self.device.available is False - assert self.device.connect(True, True) is False + assert self.device.connect() is False assert self.device.available is False - assert self.device.connect(True, True) is False + assert self.device.connect() is False assert self.device.available is False - assert self.device.connect(True, True) is False + assert self.device.connect() is False assert self.device.available is False - assert self.device.connect(True, True) is True + assert self.device.connect() is True assert self.device.available is True def test_connect_generic_exception(self) -> None: @@ -107,7 +107,7 @@ def test_connect_generic_exception(self) -> None: with patch("socket.socket.connect") as connect_mock: connect_mock.side_effect = Exception() - assert self.device.connect(True, True) is False + assert self.device.connect() is False assert self.device.available is False def test_authenticate(self) -> None: From 05be7b570a59a0041c8d51d63b6158692fc0f644 Mon Sep 17 00:00:00 2001 From: Simone Chemelli Date: Mon, 22 Jul 2024 14:22:52 +0000 Subject: [PATCH 2/6] chore: fix tests --- midealocal/device.py | 6 +++-- tests/device_test.py | 58 ++++++++++++-------------------------------- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/midealocal/device.py b/midealocal/device.py index 982de80e..b50a2ca5 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -140,7 +140,7 @@ def __init__( self._updates: list[Callable[[dict[str, Any]], None]] = [] self._unsupported_protocol: list[str] = [] self._is_run = False - self._available = True + self._available = False self._appliance_query = True self._refresh_interval = 30 self._heartbeat_interval = 10 @@ -196,7 +196,6 @@ def _authenticate_refresh_enable(self) -> bool: self.authenticate() self.refresh_status(wait_response=True) self.get_capabilities() - self.enable_device(connected) return connected def connect(self) -> bool: @@ -235,6 +234,7 @@ def connect(self) -> bool: file, lineno, ) + self.enable_device(connected) return connected def authenticate(self) -> None: @@ -242,10 +242,12 @@ def authenticate(self) -> None: request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST) _LOGGER.debug("[%s] Authentication handshaking", self._device_id) if not self._socket: + self.enable_device(False) raise SocketException self._socket.send(request) response = self._socket.recv(512) if len(response) < MIN_AUTH_RESPONSE: + self.enable_device(False) raise AuthException response = response[8:72] self._security.tcp_key(response, self._key) diff --git a/tests/device_test.py b/tests/device_test.py index 21714ce6..88b8ba03 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -1,6 +1,5 @@ """Midea Local device test.""" -from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, patch import pytest @@ -28,7 +27,7 @@ def test_fetch_v2_message() -> None: ) -class MideaDeviceTest(IsolatedAsyncioTestCase): +class MideaDeviceTest: """Midea device test case.""" device: MideaDevice @@ -59,48 +58,21 @@ def test_initial_attributes(self) -> None: assert self.device.model == "test_model" assert self.device.subtype == 1 - def test_connect(self) -> None: + @pytest.mark.parametrize( + ("exc", "result"), + [ + (TimeoutError, False), + (OSError, False), + (AuthException, False), + (RefreshFailed, False), + (None, True), + ], + ) + def test_connect(self, exc: Exception, result: bool) -> None: """Test connect.""" - with ( - patch("socket.socket.connect") as connect_mock, - patch.object( - self.device, - "authenticate", - side_effect=[AuthException(), None, None], - ), - patch.object( - self.device, - "refresh_status", - side_effect=[RefreshFailed(), None], - ), - patch.object( - self.device, - "get_capabilities", - side_effect=[None], - ), - ): - connect_mock.side_effect = [ - TimeoutError(), - OSError(), - None, - None, - None, - None, - ] - assert self.device.connect() is False - assert self.device.available is False - - assert self.device.connect() is False - assert self.device.available is False - - assert self.device.connect() is False - assert self.device.available is False - - assert self.device.connect() is False - assert self.device.available is False - - assert self.device.connect() is True - assert self.device.available is True + with patch("socket.socket.connect", side_effect=exc): + assert self.device.connect() is result + assert self.device.available is result def test_connect_generic_exception(self) -> None: """Test connect with generic exception.""" From 57a2bb19da1de5def5cb2ad06ec14b336fba16a5 Mon Sep 17 00:00:00 2001 From: Simone Chemelli Date: Tue, 23 Jul 2024 17:29:57 +0000 Subject: [PATCH 3/6] chore: close socket what connect fails --- midealocal/device.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/midealocal/device.py b/midealocal/device.py index b50a2ca5..a5a3574e 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -234,6 +234,9 @@ def connect(self) -> bool: file, lineno, ) + finally: + if self._socket: + self._socket.close() self.enable_device(connected) return connected From a96df37131fd93ca770774e85d7f4f4a73f2a998 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 10:17:37 -0300 Subject: [PATCH 4/6] chore(ci): pre-commit autoupdate (#237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.2...v0.5.4 - https://github.com/commitizen-tools/commitizen/compare/v3.27.0...v3.28.0 Co-authored-by: Lucas MindĂȘllo de Andrade --- .pre-commit-config.yaml | 4 ++-- midealocal/devices/e2/__init__.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6accf789..cd040734 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,14 +16,14 @@ repos: - id: no-commit-to-branch args: ["--branch", "main"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.2 + rev: v0.5.4 hooks: - id: ruff args: - --fix - id: ruff-format - repo: https://github.com/commitizen-tools/commitizen - rev: v3.27.0 + rev: v3.28.0 hooks: - id: commitizen stages: [commit-msg] diff --git a/midealocal/devices/e2/__init__.py b/midealocal/devices/e2/__init__.py index 98748e95..ebc874f3 100644 --- a/midealocal/devices/e2/__init__.py +++ b/midealocal/devices/e2/__init__.py @@ -103,8 +103,7 @@ def _normalize_old_protocol(self, value: str | bool | int) -> OldProtocol: if return_value == OldProtocol.auto: result = ( self.subtype <= E2SubType.T82 - or self.subtype == E2SubType.T85 - or self.subtype == E2SubType.T36353 + or self.subtype in [E2SubType.T85, E2SubType.T36353], ) return_value = OldProtocol.true if result else OldProtocol.false if isinstance(value, bool | int): From 119879df214ccbe510fef32a71fc32c96dd8008f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Mind=C3=AAllo=20de=20Andrade?= Date: Tue, 23 Jul 2024 14:25:49 -0300 Subject: [PATCH 5/6] feat(message): body parsers (#235) ## Summary by CodeRabbit ## Summary by CodeRabbit - **New Features** - Introduced a new framework for parsing message bodies, allowing for flexible and type-safe handling of various data types. - Added specific parsers for boolean, integer, and enumeration types to the message parsing logic. - Enhanced the `MessageBody` class with a method to parse multiple attributes automatically. - **Tests** - Added a comprehensive suite of unit tests for the message parsing features, ensuring functionality and error handling across various parser classes. --------- Co-authored-by: Simone Chemelli --- midealocal/message.py | 151 ++++++++++++++++++++++++++++++++- tests/message_test.py | 189 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 tests/message_test.py diff --git a/midealocal/message.py b/midealocal/message.py index cec603fa..070d05ad 100644 --- a/midealocal/message.py +++ b/midealocal/message.py @@ -2,7 +2,7 @@ import logging from enum import IntEnum -from typing import SupportsIndex, cast +from typing import Generic, SupportsIndex, TypeVar, cast _LOGGER = logging.getLogger(__name__) @@ -300,12 +300,156 @@ def body(self) -> bytearray: return bytearray([0x00] * 19) +T = TypeVar("T") +E = TypeVar("E", bound="IntEnum") + + +class BodyParser(Generic[T]): + """Body parser to decode message.""" + + def __init__( + self, + name: str, + byte: int, + bit: int | None = None, + length_in_bytes: int = 1, + first_upper: bool = True, + default_raw_value: int = 0, + ) -> None: + """Init body parser with attribute name.""" + self.name = name + self._byte = byte + self._bit = bit + self._length_in_bytes = length_in_bytes + self._first_upper = first_upper + self._default_raw_value = default_raw_value + if length_in_bytes < 0: + raise ValueError("Length in bytes must be a positive value.") + if bit is not None and (bit < 0 or bit >= length_in_bytes * 8): + raise ValueError( + "Bit, if set, must be a valid value position for %d bytes.", + length_in_bytes, + ) + + def _get_raw_value(self, body: bytearray) -> int: + """Get raw value from body.""" + if len(body) < self._byte + self._length_in_bytes: + return self._default_raw_value + data = 0 + for i in range(self._length_in_bytes): + byte = ( + self._byte + self._length_in_bytes - 1 - i + if self._first_upper + else self._byte + i + ) + data += body[byte] << (8 * i) + if self._bit is not None: + data = (data & (1 << self._bit)) >> self._bit + return data + + def get_value(self, body: bytearray) -> T: + """Get attribute value.""" + return self._parse(self._get_raw_value(body)) + + def _parse(self, raw_value: int) -> T: + """Convert raw value to attribute value.""" + raise NotImplementedError + + +class BoolParser(BodyParser[bool]): + """Bool message body parser.""" + + def __init__( + self, + name: str, + byte: int, + bit: int | None = None, + true_value: int = 1, + false_value: int = 0, + default_value: bool = True, + ) -> None: + """Init bool body parser.""" + super().__init__(name, byte, bit) + self._true_value = true_value + self._default_value = default_value + self._false_value = false_value + + def _parse(self, raw_value: int) -> bool: + if raw_value not in [self._true_value, self._false_value]: + return self._default_value + return raw_value == self._true_value + + +class IntEnumParser(BodyParser[E]): + """IntEnum message body parser.""" + + def __init__( + self, + name: str, + byte: int, + enum_class: type[E], + length_in_bytes: int = 1, + first_upper: bool = False, + default_value: E | None = None, + ) -> None: + """Init IntEnum body parser.""" + super().__init__( + name, + byte, + length_in_bytes=length_in_bytes, + first_upper=first_upper, + ) + self._enum_class = enum_class + self._default_value = default_value + + def _parse(self, raw_value: int) -> E: + try: + return self._enum_class(raw_value) + except ValueError: + return ( + self._default_value + if self._default_value is not None + else self._enum_class(0) + ) + + +class IntParser(BodyParser[int]): + """IntEnum message body parser.""" + + def __init__( + self, + name: str, + byte: int, + max_value: int = 255, + min_value: int = 0, + length_in_bytes: int = 1, + first_upper: bool = False, + ) -> None: + """Init IntEnum body parser.""" + super().__init__( + name, + byte, + length_in_bytes=length_in_bytes, + first_upper=first_upper, + ) + self._max_value = max_value + self._min_value = min_value + + def _parse(self, raw_value: int) -> int: + if raw_value > self._max_value: + return self._max_value + if raw_value < self._min_value: + return self._min_value + return raw_value + + class MessageBody: """Message body.""" def __init__(self, body: bytearray) -> None: """Initialize message body.""" self._data = body + self.parser_list: list[BodyParser] = [] @property def data(self) -> bytearray: @@ -322,6 +466,11 @@ def read_byte(body: bytearray, byte: int, default_value: int = 0) -> int: """Read bytes for message body.""" return body[byte] if len(body) > byte else default_value + def parse_all(self) -> None: + """Process parses and set body attrs.""" + for parse in self.parser_list: + setattr(self, parse.name, parse.get_value(self._data)) + class NewProtocolPackLength(IntEnum): """New Protocol Pack Length.""" diff --git a/tests/message_test.py b/tests/message_test.py new file mode 100644 index 00000000..f9a43283 --- /dev/null +++ b/tests/message_test.py @@ -0,0 +1,189 @@ +"""Midea local message test.""" + +import pytest + +from midealocal.message import ( + BodyParser, + BodyType, + BoolParser, + IntEnumParser, + IntParser, + MessageBody, +) + + +def test_init_validations() -> None: + """Test body parser init validations.""" + with pytest.raises( + ValueError, + match="Length in bytes must be a positive value.", + ): + BodyParser[int]("name", byte=3, length_in_bytes=-1) + + with pytest.raises( + ValueError, + match="('Bit, if set, must be a valid value position for %d bytes.', 2)", + ): + BodyParser[int]("name", byte=3, length_in_bytes=2, bit=-1) + + with pytest.raises( + ValueError, + match="('Bit, if set, must be a valid value position for %d bytes.', 3)", + ): + BodyParser[int]("name", byte=3, length_in_bytes=3, bit=24) + + +class TestBodyParser: + """Body parser test.""" + + @pytest.fixture(autouse=True) + def _setup_body(self) -> None: + """Create body for test.""" + self.body = bytearray( + [ + 0x00, + 0x01, + 0x02, + 0x03, + 0x04, + 0x05, + ], + ) + + def test_get_raw_value_1_byte(self) -> None: + """Test get raw value with 1 byte.""" + parser = BodyParser[int]("name", 2) + value = parser._get_raw_value(self.body) + assert value == 0x02 + + def test_get_raw_value_2_bytes(self) -> None: + """Test get raw value with 2 bytes.""" + parser = BodyParser[int]("name", 2, length_in_bytes=2) + value = parser._get_raw_value(self.body) + assert value == 0x0203 + + def test_get_raw_value_2_bytes_first_lower(self) -> None: + """Test get raw value with 2 bytes first lower.""" + parser = BodyParser[int]("name", 2, length_in_bytes=2, first_upper=False) + value = parser._get_raw_value(self.body) + assert value == 0x0302 + + def test_get_raw_out_of_bounds(self) -> None: + """Test get raw value out of bounds.""" + parser = BodyParser[int]("name", 6) + value = parser._get_raw_value(self.body) + assert value == 0 + + def test_get_raw_data_size_out_of_bounds(self) -> None: + """Test get raw value out of bounds.""" + parser = BodyParser[int]("name", 5, length_in_bytes=2) + value = parser._get_raw_value(self.body) + assert value == 0 + + def test_get_raw_data_bit(self) -> None: + """Test get raw value out of bounds.""" + for i in range(16): + parser = BodyParser[int]("name", 4, length_in_bytes=2, bit=i) + value = parser._get_raw_value(self.body) + assert value == (1 if i in [0, 2, 10] else 0) + + def test_parse_unimplemented(self) -> None: + """Test parse unimplemented.""" + parser = BodyParser[int]("name", 4, length_in_bytes=2, bit=2) + with pytest.raises(NotImplementedError): + parser.get_value(self.body) + + +class TestBoolParser: + """Test BoolParser.""" + + def test_bool_default(self) -> None: + """Test default behaviour.""" + parser = BoolParser("name", 0) + assert parser._parse(0) is False + assert parser._parse(1) is True + assert parser._parse(2) is True + + def test_bool_default_false(self) -> None: + """Test default behaviour with default value false.""" + parser = BoolParser("name", 0, default_value=False) + assert parser._parse(0) is False + assert parser._parse(1) is True + assert parser._parse(2) is False + + def test_bool_inverted(self) -> None: + """Test True=0 and False=1.""" + parser = BoolParser("name", 0, true_value=0, false_value=1) + assert parser._parse(0) is True + assert parser._parse(1) is False + assert parser._parse(2) is True + + +class TestIntEnumParser: + """Test IntEnumParser.""" + + def test_intenum_default(self) -> None: + """Test default behaviour.""" + parser = IntEnumParser[BodyType]("name", 0, BodyType) + assert parser._parse(0x01) == BodyType.X01 + assert parser._parse(0x00) == BodyType.X00 + assert parser._parse(0x10) == BodyType.X00 + + parser = IntEnumParser[BodyType]("name", 0, BodyType, default_value=BodyType.A0) + assert parser._parse(0x01) == BodyType.X01 + assert parser._parse(0x00) == BodyType.X00 + assert parser._parse(0x10) == BodyType.A0 + + +class TestIntParser: + """Test IntParser.""" + + def test_int_default(self) -> None: + """Test default behaviour.""" + parser = IntParser("name", 0) + for i in range(-10, 260): + if i < 0: + assert parser._parse(i) == 0 + elif i > 255: + assert parser._parse(i) == 255 + else: + assert parser._parse(i) == i + + +class TestMessageBody: + """Test message body.""" + + def test_parse_all(self) -> None: + """Test parse all.""" + data = bytearray( + [ + 0x00, + 0x01, + 0x02, + 0x03, + 0x04, + 0x05, + ], + ) + + body = MessageBody(data) + body.parser_list.extend( + [ + IntEnumParser("bt", 0, BodyType), + BoolParser("power", 1), + BoolParser("feature_1", 2, 0), + BoolParser("feature_2", 2, 1), + IntParser("speed", 3), + ], + ) + body.parse_all() + assert hasattr(body, "bt") is True + assert getattr(body, "bt", None) == BodyType.X00 + assert hasattr(body, "power") is True + assert getattr(body, "power", False) is True + assert hasattr(body, "feature_1") is True + assert getattr(body, "feature_1", True) is False + assert hasattr(body, "feature_2") is True + assert getattr(body, "feature_2", False) is True + assert hasattr(body, "speed") is True + assert getattr(body, "speed", 0) == 3 From c39065ca6fb2a6ba89a3e8352c567d02896ac36f Mon Sep 17 00:00:00 2001 From: Simone Chemelli Date: Wed, 24 Jul 2024 06:56:40 +0000 Subject: [PATCH 6/6] chore: improve socket management --- midealocal/device.py | 18 +++++++----------- midealocal/exceptions.py | 4 ++++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/midealocal/device.py b/midealocal/device.py index a5a3574e..ad15bf18 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -8,7 +8,7 @@ from enum import IntEnum, StrEnum from typing import Any -from .exceptions import SocketException +from .exceptions import CannotConnect, SocketException from .message import ( MessageApplianceResponse, MessageQueryAppliance, @@ -190,13 +190,11 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]: break return result, msg - def _authenticate_refresh_enable(self) -> bool: - connected = self.connect() + def _authenticate_refresh_capabilities(self) -> None: if self._protocol == ProtocolVersion.V3: self.authenticate() self.refresh_status(wait_response=True) self.get_capabilities() - return connected def connect(self) -> bool: """Connect to device.""" @@ -234,9 +232,6 @@ def connect(self) -> bool: file, lineno, ) - finally: - if self._socket: - self._socket.close() self.enable_device(connected) return connected @@ -515,13 +510,14 @@ def _check_heartbeat(self, now: float) -> None: def run(self) -> None: """Run loop.""" while self._is_run: - if not self._socket or not self.connect(): + if not self.connect(): + raise CannotConnect + if not self._socket: raise SocketException - self._authenticate_refresh_enable() + self._authenticate_refresh_capabilities() timeout_counter = 0 start = time.time() - self._previous_refresh = start - self._previous_heartbeat = start + self._previous_refresh = self._previous_heartbeat = start self._socket.settimeout(1) while True: try: diff --git a/midealocal/exceptions.py b/midealocal/exceptions.py index f2c7ba95..7cd5c74d 100644 --- a/midealocal/exceptions.py +++ b/midealocal/exceptions.py @@ -11,6 +11,10 @@ class CannotAuthenticate(MideaLocalError): """Exception raised when credentials are incorrect.""" +class CannotConnect(MideaLocalError): + """Exception raised when connection fails.""" + + class DataUnexpectedLength(MideaLocalError): """Exception raised when data length is less or more than expected."""