From f9f747590eea49f5feff3c04b724fc3c27f96a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Mind=C3=AAllo=20de=20Andrade?= Date: Fri, 6 Sep 2024 12:56:24 +0000 Subject: [PATCH 1/4] fix(device): socket exception and process rebuild --- midealocal/cli.py | 9 +- midealocal/device.py | 361 +++++++++++++++++++++++++++---------------- tests/cli_test.py | 6 +- tests/device_test.py | 29 ++-- 4 files changed, 253 insertions(+), 152 deletions(-) diff --git a/midealocal/cli.py b/midealocal/cli.py index 11e7b453..dcad5332 100644 --- a/midealocal/cli.py +++ b/midealocal/cli.py @@ -20,7 +20,12 @@ get_midea_cloud, get_preset_account_cloud, ) -from midealocal.device import AuthException, MideaDevice, ProtocolVersion, RefreshFailed +from midealocal.device import ( + AuthException, + MideaDevice, + NoSupportedProtocol, + ProtocolVersion, +) from midealocal.devices import device_selector from midealocal.discover import discover from midealocal.exceptions import SocketException @@ -121,7 +126,7 @@ async def discover(self) -> list[MideaDevice]: _LOGGER.debug("Unable to connect with key: %s", key) except SocketException: _LOGGER.exception("Device socket closed.") - except RefreshFailed: + except NoSupportedProtocol: _LOGGER.exception("Unable to retrieve device attributes.") else: _LOGGER.info("Found device:\n%s", dev.attributes) diff --git a/midealocal/device.py b/midealocal/device.py index 4a2bc460..b0cb6c16 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -10,7 +10,7 @@ from typing_extensions import deprecated -from .exceptions import CannotConnect, SocketException +from .exceptions import SocketException from .message import ( MessageApplianceResponse, MessageQueryAppliance, @@ -28,7 +28,9 @@ MIN_AUTH_RESPONSE = 20 MIN_MSG_LENGTH = 56 MIN_V2_FACTUAL_MSG_LENGTH = 6 -RESPONSE_TIMEOUT = 120 +SOCKET_TIMEOUT = 10 # socket connection default timeout +QUERY_TIMEOUT = 2 # query response in 1s, 0xAC have more queries, set to 2s + _LOGGER = logging.getLogger(__name__) @@ -81,8 +83,8 @@ class ResponseException(Exception): """Response exception.""" -class RefreshFailed(Exception): - """Refresh failed exception.""" +class NoSupportedProtocol(Exception): + """Query device failed exception.""" class DeviceAttributes(StrEnum): @@ -97,11 +99,14 @@ class ProtocolVersion(IntEnum): V3 = 3 -class ParseMessageResult(IntEnum): +class MessageResult(IntEnum): """Parse message result.""" - SUCCESS = 0 - PADDING = 1 + PADDING = 0 + SUCCESS = 1 + UNKNOWN = 96 + UNEXPECTED = 97 + TIMEOUT = 98 ERROR = 99 @@ -145,7 +150,7 @@ def __init__( self._available = False self._appliance_query = True self._refresh_interval = 30 - self._heartbeat_interval = 10 + self._heartbeat_interval = SOCKET_TIMEOUT self._default_refresh_interval = 30 self._previous_refresh = 0.0 self._previous_heartbeat = 0.0 @@ -192,50 +197,38 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]: break return result, msg - def _authenticate_refresh_capabilities(self) -> None: - if self._protocol == ProtocolVersion.V3: - self.authenticate() - self.refresh_status(wait_response=True) - self.get_capabilities() - def connect(self) -> bool: """Connect to device.""" connected = False - for _ in range(3): - try: - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.settimeout(10) - _LOGGER.debug( - "[%s] Connecting to %s:%s", - self._device_id, - self._ip_address, - self._port, - ) - self._socket.connect((self._ip_address, self._port)) - _LOGGER.debug("[%s] Connected", self._device_id) - connected = True - break - except TimeoutError: - _LOGGER.debug("[%s] Connection timed out", self._device_id) - except OSError: - _LOGGER.debug("[%s] Connection error", self._device_id) - except AuthException: - _LOGGER.debug("[%s] Authentication failed", self._device_id) - except RefreshFailed: - _LOGGER.debug("[%s] Refresh status is timed out", self._device_id) - except Exception as e: - file = None - lineno = None - if e.__traceback__: - file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101 - lineno = e.__traceback__.tb_lineno - _LOGGER.exception( - "[%s] Unknown error : %s, %s", - self._device_id, - file, - lineno, - ) - self.enable_device(connected) + try: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.settimeout(SOCKET_TIMEOUT) + _LOGGER.debug( + "[%s] Connecting to %s:%s", + self._device_id, + self._ip_address, + self._port, + ) + self._socket.connect((self._ip_address, self._port)) + _LOGGER.debug("[%s] Connected", self._device_id) + connected = True + except TimeoutError: + _LOGGER.debug("[%s] Connection timed out", self._device_id) + # set _socket to None when connect exception matched + self._socket = None + except OSError: + _LOGGER.debug("[%s] Connection error", self._device_id) + # set _socket to None when connect exception matched + self._socket = None + except Exception as e: + _LOGGER.exception( + "[%s] Unknown error during connect device", + self._device_id, + exc_info=e, + ) + # set _socket to None when connect exception matched + self._socket = None + self.set_available(connected) return connected def authenticate(self) -> None: @@ -245,8 +238,16 @@ def authenticate(self) -> None: if not self._socket: self.enable_device(False) raise SocketException - self._socket.send(request) - response = self._socket.recv(512) + try: + self._socket.send(request) + response = self._socket.recv(512) + except Exception as e: + _LOGGER.exception( + "[%s] authenticate Unexpected socket error", + self._device_id, + exc_info=e, + ) + self.close_socket() _LOGGER.debug( "[%s] Received auth response with %d bytes: %s", self._device_id, @@ -260,20 +261,30 @@ def authenticate(self) -> None: self._security.tcp_key(response, self._key) _LOGGER.debug("[%s] Authentication success", self._device_id) - def send_message(self, data: bytes) -> None: + def send_message(self, data: bytes, query: bool = False) -> None: """Send message.""" if self._protocol == ProtocolVersion.V3: - self.send_message_v3(data, msg_type=MSGTYPE_ENCRYPTED_REQUEST) + self.send_message_v3(data, msg_type=MSGTYPE_ENCRYPTED_REQUEST, query=query) else: - self.send_message_v2(data) + self.send_message_v2(data, query=query) - def send_message_v2(self, data: bytes) -> None: + def send_message_v2(self, data: bytes, query: bool = False) -> None: """Send message V2.""" if self._socket is not None: - self._socket.send(data) + try: + if query: + self._socket.settimeout(QUERY_TIMEOUT) + self._socket.send(data) + except Exception as e: + _LOGGER.exception( + "[%s] send_message_v2 Unexpected socket error", + self._device_id, + exc_info=e, + ) + self.close_socket() else: _LOGGER.debug( - "[%s] Send failure, device disconnected, data: %s", + "[%s] Send failure, device socket is none, data: %s", self._device_id, data.hex(), ) @@ -282,17 +293,18 @@ def send_message_v3( self, data: bytes, msg_type: int = MSGTYPE_ENCRYPTED_REQUEST, + query: bool = False, ) -> None: """Send message V3.""" data = self._security.encode_8370(data, msg_type) - self.send_message_v2(data) + self.send_message_v2(data, query=query) - def build_send(self, cmd: MessageRequest) -> None: + def build_send(self, cmd: MessageRequest, query: bool = False) -> None: """Serialize and send.""" data = cmd.serialize() _LOGGER.debug("[%s] Sending: %s", self._device_id, cmd) msg = PacketBuilder(self._device_id, data).finalize() - self.send_message(msg) + self.send_message(msg, query=query) def get_capabilities(self) -> None: """Get device capabilities.""" @@ -300,7 +312,35 @@ def get_capabilities(self) -> None: for cmd in cmds: self.build_send(cmd) - def refresh_status(self, wait_response: bool = False) -> None: + def _recv_message(self) -> dict[str, MessageResult | bytes]: + """Recv message.""" + if not self._socket: + _LOGGER.warning("[%s] _recv_message socket error", self._device_id) + raise SocketException + try: + msg = self._socket.recv(512) + if len(msg) == 0: + _LOGGER.warning("[%s] Empty msg received", self._device_id) + return {"result": MessageResult.PADDING} + if msg: + return {"result": MessageResult.SUCCESS, "msg": msg} + except TimeoutError: + _LOGGER.debug("[%s] _recv_message Socket timed out", self._device_id) + # close socket when exception matched + self.close_socket() + return {"result": MessageResult.TIMEOUT} + except Exception as e: + _LOGGER.exception( + "[%s] Unexpected socket error", + self._device_id, + exc_info=e, + ) + # close socket when exception matched + self.close_socket() + return {"result": MessageResult.UNEXPECTED} + return {"result": MessageResult.UNKNOWN} # Add a fallback return + + def refresh_status(self, check_protocol: bool = False) -> None: """Refresh device status.""" cmds: list = self.build_query() if self._appliance_query: @@ -308,40 +348,77 @@ def refresh_status(self, wait_response: bool = False) -> None: error_count = 0 for cmd in cmds: if cmd.__class__.__name__ not in self._unsupported_protocol: - self.build_send(cmd) - if wait_response: - try: - while True: - if not self._socket: - raise SocketException - msg = self._socket.recv(512) - if len(msg) == 0: - raise OSError("Empty message received.") - result = self.parse_message(msg) - if result == ParseMessageResult.SUCCESS: + # set query flag for query timeout + self.build_send(cmd, query=True) + response = self._recv_message() + # recovery timeout after _recv_message + self._recovery_timeout() + # normal msg + if response.get("result") == MessageResult.SUCCESS: + if response.get("msg"): + # parse response + msg = response.get("msg") + if isinstance(msg, bytes): + result = self.parse_message(msg=msg) + if result == MessageResult.SUCCESS: break - if result == ParseMessageResult.PADDING: - continue - error_count += 1 - except TimeoutError: + # msg padding + continue + # empty msg + elif response.get("result") == MessageResult.PADDING: + continue + # timeout msg + elif response.get("result") == MessageResult.TIMEOUT: + _LOGGER.debug( + "[%s] protocol %s, cmd %s, timeout", + self._device_id, + cmd.__class__.__name__, + cmd, + ) + # init connection, add timeout protocol to unsupported list + if check_protocol: error_count += 1 self._unsupported_protocol.append(cmd.__class__.__name__) _LOGGER.debug( - "[%s] Does not supports the protocol %s, ignored", + "[%s] Does not supports the protocol %s, cmd %s, ignored", self._device_id, cmd.__class__.__name__, + cmd, ) - else: - error_count += 1 - if error_count == len(cmds): - raise RefreshFailed + # exception msg + else: + _LOGGER.debug( + "[%s] protocol %s, cmd %s, response exception %s", + self._device_id, + cmd.__class__.__name__, + cmd, + response, + ) + # init connection, add exception protocol to unsupported list + if check_protocol: + error_count += 1 + self._unsupported_protocol.append(cmd.__class__.__name__) + _LOGGER.debug( + "[%s] Does not supports the protocol %s, cmd %s, ignored", + self._device_id, + cmd.__class__.__name__, + cmd, + ) + # init connection and all the query failed, raise error + if check_protocol and error_count == len(cmds): + _LOGGER.debug( + "[%s] all the query cmds failed %s, please report bug", + self._device_id, + cmds, + ) + raise NoSupportedProtocol def pre_process_message(self, msg: bytearray) -> bool: """Pre process message.""" if msg[9] == MessageType.query_appliance: message = MessageApplianceResponse(msg) self._appliance_query = False - _LOGGER.debug("[%s] Received: %s", self._device_id, message) + _LOGGER.debug("[%s] Appliance query Received: %s", self._device_id, message) self._protocol_version = message.protocol_version _LOGGER.debug( "[%s] Device protocol version: %s", @@ -351,17 +428,17 @@ def pre_process_message(self, msg: bytearray) -> bool: return False return True - def parse_message(self, msg: bytes) -> ParseMessageResult: + def parse_message(self, msg: bytes) -> MessageResult: """Parse message.""" if self._protocol == ProtocolVersion.V3: messages, self._buffer = self._security.decode_8370(self._buffer + msg) else: messages, self._buffer = self.fetch_v2_message(self._buffer + msg) if len(messages) == 0: - return ParseMessageResult.PADDING + return MessageResult.PADDING for message in messages: if message == b"ERROR": - return ParseMessageResult.ERROR + return MessageResult.ERROR payload_len = message[4] + (message[5] << 8) - 56 payload_type = message[2] + (message[3] << 8) if payload_type in [0x1001, 0x0001]: @@ -384,7 +461,6 @@ def parse_message(self, msg: bytes) -> ParseMessageResult: "[%s] Unidentified protocol", self._device_id, ) - except Exception: _LOGGER.exception( "[%s] Error in process message, msg = %s", @@ -419,7 +495,7 @@ def parse_message(self, msg: bytes) -> ParseMessageResult: payload_len, len(message), ) - return ParseMessageResult.SUCCESS + return MessageResult.SUCCESS def build_query(self) -> list: """Build query.""" @@ -501,8 +577,13 @@ def close_socket(self) -> None: self._unsupported_protocol = [] self._buffer = b"" if self._socket: - self._socket.close() - self._socket = None + try: + self._socket.close() + _LOGGER.debug("[%s] Socket closed", self._device_id) + except OSError as e: + _LOGGER.debug("[%s] Error while closing socket: %s", self._device_id, e) + finally: + self._socket = None def set_ip_address(self, ip_address: str) -> None: """Set IP address.""" @@ -525,63 +606,83 @@ def _check_heartbeat(self, now: float) -> None: self.send_heartbeat() self._previous_heartbeat = now + def _recovery_timeout(self) -> None: + if not self._socket: + _LOGGER.warning("[%s] _recovery_timeout socket error", self._device_id) + raise SocketException + try: + self._socket.settimeout(SOCKET_TIMEOUT) + except TimeoutError: + self.close_socket() + _LOGGER.debug("_recovery_timeout socket timeout") + def run(self) -> None: """Run loop.""" + connection_retries = 0 while self._is_run: - if not self.connect(): - raise CannotConnect - if not self._socket: - raise SocketException - self._authenticate_refresh_capabilities() - timeout_counter = 0 + # init connection or socket broken, socket connect/reconnect + while self._socket is None: + _LOGGER.debug("[%s] Socket is None, try to connect", self._device_id) + # connect and check result + if not self.connect(): + self.close_socket() + connection_retries += 1 + sleep_time = min(60 * connection_retries, 600) + _LOGGER.warning( + "[%s] Unable to connect, sleep %s seconds and retry", + self._device_id, + sleep_time, + ) + # sleep and reconnect loop + time.sleep(sleep_time) + continue + # connect pass, auth for v3 device + if self._protocol == ProtocolVersion.V3: + self.authenticate() + try: + # probe device with query and check response + self.refresh_status(check_protocol=True) + except NoSupportedProtocol: + _LOGGER.debug( + "[%s] query device failed, please report bug", + self._device_id, + ) + break + except SocketException: + _LOGGER.debug( + "[%s] socket error, close and reconnect", + self._device_id, + ) + self.close_socket() + continue + self.get_capabilities() + # socket exist + connection_retries = 0 start = time.time() self._previous_refresh = self._previous_heartbeat = start - self._socket.settimeout(1) + # loop in query and parse response while True: try: + # check refresh process now = time.time() self._check_refresh(now) + # check heartbeat + now = time.time() self._check_heartbeat(now) - msg = self._socket.recv(512) - if len(msg) == 0: - if self._is_run: - _LOGGER.error( - "[%s] Socket error - Connection closed by peer", - self._device_id, - ) - self.close_socket() - break - result = self.parse_message(msg) - if result == ParseMessageResult.ERROR: - _LOGGER.debug("[%s] Message 'ERROR' received", self._device_id) - self.close_socket() - break - if result == ParseMessageResult.SUCCESS: - timeout_counter = 0 except TimeoutError: - timeout_counter = timeout_counter + 1 - if timeout_counter >= RESPONSE_TIMEOUT: - _LOGGER.debug("[%s] Heartbeat timed out", self._device_id) - self.close_socket() - break - except OSError as e: - if self._is_run: - _LOGGER.debug("[%s] Socket error %s", self._device_id, repr(e)) - self.close_socket() + _LOGGER.debug("[%s] Socket timed out", self._device_id) + self.close_socket() + break + except NoSupportedProtocol: + _LOGGER.debug("[%s] query device failed", self._device_id) + self.close_socket() break except Exception as e: - file = None - lineno = None - if e.__traceback__: - file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101 - lineno = e.__traceback__.tb_lineno _LOGGER.exception( - "[%s] Unknown error : %s, %s", + "[%s] Unexpected error", self._device_id, - file, - lineno, + exc_info=e, ) - self.close_socket() break diff --git a/tests/cli_test.py b/tests/cli_test.py index b1b663f0..a7c9d006 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -14,7 +14,7 @@ get_config_file_path, ) from midealocal.cloud import MSmartHomeCloud -from midealocal.device import AuthException, ProtocolVersion, RefreshFailed +from midealocal.device import AuthException, NoSupportedProtocol, ProtocolVersion from midealocal.exceptions import SocketException @@ -134,7 +134,7 @@ async def test_discover(self) -> None: patch.object( mock_device_instance, "refresh_status", - side_effect=[None, None, RefreshFailed, None], + side_effect=[None, None, NoSupportedProtocol, None], ) as refresh_status_mock, ): mock_discover.return_value = {1: mock_device} @@ -157,7 +157,7 @@ async def test_discover(self) -> None: authenticate_mock.reset_mock() mock_device["protocol"] = ProtocolVersion.V2 - await self.cli.discover() # V2 device RefreshFailed + await self.cli.discover() # V2 device NoSupportedProtocol authenticate_mock.assert_not_called() refresh_status_mock.assert_called_once() diff --git a/tests/device_test.py b/tests/device_test.py index 88b8ba03..21fb699a 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -7,10 +7,10 @@ from midealocal.cloud import DEFAULT_KEYS from midealocal.device import ( AuthException, + MessageResult, MideaDevice, - ParseMessageResult, + NoSupportedProtocol, ProtocolVersion, - RefreshFailed, ) from midealocal.devices.ac.message import MessageCapabilitiesQuery from midealocal.exceptions import SocketException @@ -64,7 +64,7 @@ def test_initial_attributes(self) -> None: (TimeoutError, False), (OSError, False), (AuthException, False), - (RefreshFailed, False), + (NoSupportedProtocol, False), (None, True), ], ) @@ -229,9 +229,9 @@ def test_refresh_status(self) -> None: self.device, "parse_message", side_effect=[ - ParseMessageResult.SUCCESS, - ParseMessageResult.PADDING, - ParseMessageResult.ERROR, + MessageResult.SUCCESS, + MessageResult.PADDING, + MessageResult.ERROR, ], ), ): @@ -246,11 +246,11 @@ def test_refresh_status(self) -> None: self.device.refresh_status(True) # SUCCESS self.device.refresh_status(True) # PADDING - with pytest.raises(RefreshFailed): + with pytest.raises(NoSupportedProtocol): self.device.refresh_status(True) # ERROR - with pytest.raises(RefreshFailed): + with pytest.raises(NoSupportedProtocol): self.device.refresh_status(True) # Timeout - with pytest.raises(RefreshFailed): + with pytest.raises(NoSupportedProtocol): self.device.refresh_status(True) # Unsupported protocol def test_parse_message(self) -> None: @@ -281,20 +281,15 @@ def test_parse_message(self) -> None: ], ), ): - assert ( - self.device.parse_message(bytearray([])) == ParseMessageResult.PADDING - ) + assert self.device.parse_message(bytearray([])) == MessageResult.PADDING self.device._protocol = ProtocolVersion.V2 - assert self.device.parse_message(bytearray([])) == ParseMessageResult.ERROR + assert self.device.parse_message(bytearray([])) == MessageResult.ERROR with patch.object( self.device, "process_message", side_effect=[{"power": True}, {}, NotImplementedError()], ): - assert ( - self.device.parse_message(bytearray([])) - == ParseMessageResult.SUCCESS - ) + assert self.device.parse_message(bytearray([])) == MessageResult.SUCCESS def test_pre_process_message(self) -> None: """Test pre process message.""" From 5ff3686432c94e8c081ed3f1b243f66d80830ae3 Mon Sep 17 00:00:00 2001 From: wuwentao Date: Thu, 19 Sep 2024 14:37:07 +0000 Subject: [PATCH 2/4] chore: fix check_protocol got timeout and not close socket --- midealocal/device.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/midealocal/device.py b/midealocal/device.py index b0cb6c16..0ff01de3 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -312,7 +312,10 @@ def get_capabilities(self) -> None: for cmd in cmds: self.build_send(cmd) - def _recv_message(self) -> dict[str, MessageResult | bytes]: + def _recv_message( + self, + check_protocol: bool = False, + ) -> dict[str, MessageResult | bytes]: """Recv message.""" if not self._socket: _LOGGER.warning("[%s] _recv_message socket error", self._device_id) @@ -325,9 +328,14 @@ def _recv_message(self) -> dict[str, MessageResult | bytes]: if msg: return {"result": MessageResult.SUCCESS, "msg": msg} except TimeoutError: - _LOGGER.debug("[%s] _recv_message Socket timed out", self._device_id) - # close socket when exception matched - self.close_socket() + _LOGGER.debug( + "[%s] _recv_message Socket timed out with check_protocol %s", + self._device_id, + check_protocol, + ) + # close socket when timeout and not check_protocol + if not check_protocol: + self.close_socket() return {"result": MessageResult.TIMEOUT} except Exception as e: _LOGGER.exception( @@ -350,7 +358,7 @@ def refresh_status(self, check_protocol: bool = False) -> None: if cmd.__class__.__name__ not in self._unsupported_protocol: # set query flag for query timeout self.build_send(cmd, query=True) - response = self._recv_message() + response = self._recv_message(check_protocol=check_protocol) # recovery timeout after _recv_message self._recovery_timeout() # normal msg @@ -360,10 +368,13 @@ def refresh_status(self, check_protocol: bool = False) -> None: msg = response.get("msg") if isinstance(msg, bytes): result = self.parse_message(msg=msg) - if result == MessageResult.SUCCESS: - break - # msg padding - continue + if result != MessageResult.SUCCESS: + _LOGGER.error( + "[%s] parse_message %s result is %s", + self._device_id, + msg, + result, + ) # empty msg elif response.get("result") == MessageResult.PADDING: continue From a602d9fda4be6603beefe1fc1e898a0bc0e3706f Mon Sep 17 00:00:00 2001 From: Wentao Wu Date: Fri, 20 Sep 2024 02:07:18 +0000 Subject: [PATCH 3/4] chore: fix recovery socket only recv msg success/padding --- midealocal/device.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/midealocal/device.py b/midealocal/device.py index 0ff01de3..269a87c1 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -359,10 +359,11 @@ def refresh_status(self, check_protocol: bool = False) -> None: # set query flag for query timeout self.build_send(cmd, query=True) response = self._recv_message(check_protocol=check_protocol) - # recovery timeout after _recv_message - self._recovery_timeout() # normal msg if response.get("result") == MessageResult.SUCCESS: + # recovery timeout after _recv_message is success/padding + # for exception/timeout result, self._socket closed + self._recovery_timeout() if response.get("msg"): # parse response msg = response.get("msg") @@ -377,6 +378,9 @@ def refresh_status(self, check_protocol: bool = False) -> None: ) # empty msg elif response.get("result") == MessageResult.PADDING: + # recovery timeout after _recv_message is success/padding + # for exception/timeout result, self._socket closed + self._recovery_timeout() continue # timeout msg elif response.get("result") == MessageResult.TIMEOUT: @@ -628,7 +632,26 @@ def _recovery_timeout(self) -> None: _LOGGER.debug("_recovery_timeout socket timeout") def run(self) -> None: - """Run loop.""" + """Run loop brief description. + + 1. first/init connection, self._socket is None + 1.1 connect() device loop, pass, enable device + 1.2 auth for v3 device, MUST pass for v3 device + 1.3 init refresh_status, send all query and check supported protocol + 1.3.1 set socket timeout to QUERY_TIMEOUT before send query + 1.3.2 get response and add timeout query cmd to not supported + 1.3.1 parse recv response/status for supported protocol + 1.4 get_capabilities() + 2. after socket/device connected, loop for heartbeat/refresh_status + 3. job1: check refresh_interval + 3.1 socket/device connection should exist + 3.2 send only supported query to get response and refresh status + 3.3 set socket query timeout and recovery after recv msg + 4. job2: check heartbeat interval + 4.1 socket connection should exist + 4.2 send heartbeat packet to keep alive + + """ connection_retries = 0 while self._is_run: # init connection or socket broken, socket connect/reconnect From b2dd254f05582fad0d2cfa0cecd511606ba6becf Mon Sep 17 00:00:00 2001 From: Wentao Wu Date: Fri, 20 Sep 2024 03:47:16 +0000 Subject: [PATCH 4/4] chore: all the socket exception should continue and reconnect --- midealocal/device.py | 79 ++++++++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/midealocal/device.py b/midealocal/device.py index 269a87c1..0e303dc5 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -236,8 +236,9 @@ def authenticate(self) -> None: request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST) _LOGGER.debug("[%s] Authentication handshaking", self._device_id) if not self._socket: + _LOGGER.debug("[%s] socket is None, close and return", self._device_id) self.enable_device(False) - raise SocketException + return try: self._socket.send(request) response = self._socket.recv(512) @@ -256,6 +257,12 @@ def authenticate(self) -> None: ) if len(response) < MIN_AUTH_RESPONSE: self.enable_device(False) + _LOGGER.debug( + "[%s] Received auth response len %d error, bytes: %s", + self._device_id, + len(response), + response.hex(), + ) raise AuthException response = response[8:72] self._security.tcp_key(response, self._key) @@ -317,8 +324,9 @@ def _recv_message( check_protocol: bool = False, ) -> dict[str, MessageResult | bytes]: """Recv message.""" + # already connected and socket error if not self._socket: - _LOGGER.warning("[%s] _recv_message socket error", self._device_id) + _LOGGER.debug("[%s] _recv_message socket error, reconnect", self._device_id) raise SocketException try: msg = self._socket.recv(512) @@ -356,14 +364,23 @@ def refresh_status(self, check_protocol: bool = False) -> None: error_count = 0 for cmd in cmds: if cmd.__class__.__name__ not in self._unsupported_protocol: - # set query flag for query timeout - self.build_send(cmd, query=True) - response = self._recv_message(check_protocol=check_protocol) - # normal msg - if response.get("result") == MessageResult.SUCCESS: + # catch socket exception and continue + try: + # set query flag for query timeout + self.build_send(cmd, query=True) + # recv socket message send query + response = self._recv_message(check_protocol=check_protocol) # recovery timeout after _recv_message is success/padding - # for exception/timeout result, self._socket closed self._recovery_timeout() + except SocketException: + _LOGGER.debug( + "[%s] refresh_status socket error, close and reconnect", + self._device_id, + ) + self.close_socket() + break + # normal msg + if response.get("result") == MessageResult.SUCCESS: if response.get("msg"): # parse response msg = response.get("msg") @@ -378,9 +395,6 @@ def refresh_status(self, check_protocol: bool = False) -> None: ) # empty msg elif response.get("result") == MessageResult.PADDING: - # recovery timeout after _recv_message is success/padding - # for exception/timeout result, self._socket closed - self._recovery_timeout() continue # timeout msg elif response.get("result") == MessageResult.TIMEOUT: @@ -631,6 +645,27 @@ def _recovery_timeout(self) -> None: self.close_socket() _LOGGER.debug("_recovery_timeout socket timeout") + def _connect_loop(self) -> None: + """Connect loop until device online.""" + # init connection or socket broken, socket loop until device online + connection_retries = 0 + while self._socket is None: + _LOGGER.debug("[%s] Socket is None, try to connect", self._device_id) + # connect and check result + if not self.connect(): + self.close_socket() + connection_retries += 1 + sleep_time = min(60 * connection_retries, 600) + _LOGGER.warning( + "[%s] Unable to connect, sleep %s seconds and retry", + self._device_id, + sleep_time, + ) + # sleep and reconnect loop + time.sleep(sleep_time) + continue + connection_retries = 0 + def run(self) -> None: """Run loop brief description. @@ -652,24 +687,11 @@ def run(self) -> None: 4.2 send heartbeat packet to keep alive """ - connection_retries = 0 while self._is_run: - # init connection or socket broken, socket connect/reconnect - while self._socket is None: - _LOGGER.debug("[%s] Socket is None, try to connect", self._device_id) - # connect and check result - if not self.connect(): - self.close_socket() - connection_retries += 1 - sleep_time = min(60 * connection_retries, 600) - _LOGGER.warning( - "[%s] Unable to connect, sleep %s seconds and retry", - self._device_id, - sleep_time, - ) - # sleep and reconnect loop - time.sleep(sleep_time) - continue + # init connection + if self._socket is None: + # connect device loop + self._connect_loop() # connect pass, auth for v3 device if self._protocol == ProtocolVersion.V3: self.authenticate() @@ -691,7 +713,6 @@ def run(self) -> None: continue self.get_capabilities() # socket exist - connection_retries = 0 start = time.time() self._previous_refresh = self._previous_heartbeat = start # loop in query and parse response