Skip to content

Commit

Permalink
feat: segregate connect/auth/refresh/enable device duties
Browse files Browse the repository at this point in the history
  • Loading branch information
chemelli74 committed Jul 19, 2024
1 parent 8178e76 commit ae8f1ea
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 56 deletions.
99 changes: 49 additions & 50 deletions midealocal/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,59 +190,57 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]:
break
return result, msg

def connect(
self,
refresh_status: bool = True,
get_capabilities: bool = True,
) -> bool:
def _authenticate_refresh_enable(self) -> bool:
connected = self.connect()
if self._protocol == ProtocolVersion.V3:
self.authenticate()
self.refresh_status(wait_response=True)
self.get_capabilities()
self.enable_device(connected)
return connected

def connect(self) -> bool:
"""Connect to device."""
connected = False
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
if self._protocol == ProtocolVersion.V3:
self.authenticate()
_LOGGER.debug("[%s] Authentication success", self._device_id)
if refresh_status:
self.refresh_status(wait_response=True)
if get_capabilities:
self.get_capabilities()
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
self.enable_device(connected)
for _ in range(3):
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
return connected

def authenticate(self) -> None:
"""Authenticate to device. V3 only."""
request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST)
_LOGGER.debug("[%s] Handshaking", self._device_id)
_LOGGER.debug("[%s] Authentication handshaking", self._device_id)
if not self._socket:
raise SocketException
self._socket.send(request)
Expand All @@ -251,6 +249,7 @@ def authenticate(self) -> None:
raise AuthException
response = response[8:72]
self._security.tcp_key(response, self._key)
_LOGGER.debug("[%s] Authentication success", self._device_id)

def send_message(self, data: bytes) -> None:
"""Send message."""
Expand Down Expand Up @@ -462,6 +461,7 @@ def update_all(self, status: dict[str, Any]) -> None:

def enable_device(self, available: bool = True) -> None:
"""Enable device."""
_LOGGER.debug("[%s] Enabling device", self._device_id)
self._available = available
status = {"available": available}
self.update_all(status)
Expand Down Expand Up @@ -510,10 +510,9 @@ def _check_heartbeat(self, now: float) -> None:
def run(self) -> None:
"""Run loop."""
while self._is_run:
while self._socket is None:
if self.connect(refresh_status=True) is False:
self.close_socket()
time.sleep(5)
if not self._socket or not self.connect():
raise SocketException
self._authenticate_refresh_enable()
timeout_counter = 0
start = time.time()
self._previous_refresh = start
Expand Down
12 changes: 6 additions & 6 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,27 @@ def test_connect(self) -> None:
None,
None,
]
assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

assert self.device.connect(True, True) is True
assert self.device.connect() is True
assert self.device.available is True

def test_connect_generic_exception(self) -> None:
"""Test connect with generic exception."""
with patch("socket.socket.connect") as connect_mock:
connect_mock.side_effect = Exception()

assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

def test_authenticate(self) -> None:
Expand Down

0 comments on commit ae8f1ea

Please sign in to comment.