Skip to content

Commit

Permalink
chore: typing part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
chemelli74 committed May 27, 2024
1 parent cf798af commit 69e9d94
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 34 deletions.
69 changes: 42 additions & 27 deletions midealocal/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import base64
from threading import Lock
from typing import Any, cast
from aiohttp import ClientSession
from secrets import token_hex
from .security import (
Expand Down Expand Up @@ -78,7 +79,7 @@ def __init__(
account: str,
password: str,
api_url: str,
):
) -> None:
self._device_id = CloudSecurity.get_deviceid(account)
self._session = session
self._security = security
Expand All @@ -88,14 +89,16 @@ def __init__(
self._account = account
self._password = password
self._api_url = api_url
self._access_token = None
self._uid = None
self._access_token: str | None = None
self._uid: str | None = None
self._login_id = ""

def _make_general_data(self):
def _make_general_data(self) -> dict[Any, Any]:
return {}

async def _api_request(self, endpoint: str, data: dict, header=None) -> dict | None:
async def _api_request(
self, endpoint: str, data: dict, header: dict[Any, Any] | None = None
) -> dict | None:
header = header or {}
if not data.get("reqId"):
data.update({"reqId": token_hex(16)})
Expand Down Expand Up @@ -133,7 +136,7 @@ async def _api_request(self, endpoint: str, data: dict, header=None) -> dict | N
except Exception as e:
_LOGGER.warning(f"Midea cloud API error, url: {url}, error: {repr(e)}")
if int(response["code"]) == 0 and "data" in response:
return response["data"]
return cast(dict, response["data"])
return None

async def _get_login_id(self) -> str | None:
Expand All @@ -148,7 +151,7 @@ async def _get_login_id(self) -> str | None:
async def login(self) -> bool:
raise NotImplementedError()

async def get_keys(self, appliance_id: int):
async def get_keys(self, appliance_id: int) -> dict[int, dict[str, Any]]:
result = {}
for method in [1, 2]:
udp_id = self._security.get_udp_id(appliance_id, method)
Expand All @@ -167,16 +170,18 @@ async def get_keys(self, appliance_id: int):
result.update(default_keys)
return result

async def list_home(self) -> dict | None:
async def list_home(self) -> dict[int, Any] | None:
return {1: "My home"}

async def list_appliances(self, home_id) -> dict | None:
async def list_appliances(
self, home_id: str | None
) -> dict[int, dict[str, Any]] | None:
raise NotImplementedError()

async def get_device_info(self, device_id: int):
async def get_device_info(self, device_id: str) -> dict[str, Any] | None:
if response := await self.list_appliances(home_id=None):
if device_id in response.keys():
return response[device_id]
if int(device_id) in response.keys():
return cast(dict, response[int(device_id)])
return None

async def download_lua(
Expand All @@ -186,7 +191,7 @@ async def download_lua(
sn: str,
model_number: str | None,
manufacturer_code: str = "0000",
):
) -> str | None:
raise NotImplementedError()


Expand Down Expand Up @@ -250,7 +255,7 @@ async def login(self) -> bool:
return True
return False

async def list_home(self):
async def list_home(self) -> dict[int, Any] | None:
if response := await self._api_request(
endpoint="/v1/homegroup/list/get", data={}
):
Expand All @@ -260,7 +265,9 @@ async def list_home(self):
return homes
return None

async def list_appliances(self, home_id) -> dict | None:
async def list_appliances(
self, home_id: str | None
) -> dict[int, dict[str, Any]] | None:
data = {"homegroupId": home_id}
if response := await self._api_request(
endpoint="/v1/appliance/home/list/get", data=data
Expand Down Expand Up @@ -299,7 +306,7 @@ async def list_appliances(self, home_id) -> dict | None:
return appliances
return None

async def get_device_info(self, device_id: int):
async def get_device_info(self, device_id: str) -> dict[str, Any] | None:
data = {"applianceCode": device_id}
if response := await self._api_request(
endpoint="/v1/appliance/info/get", data=data
Expand Down Expand Up @@ -339,7 +346,7 @@ async def download_lua(
sn: str,
model_number: str | None,
manufacturer_code: str = "0000",
):
) -> str | None:
data = {
"applianceSn": sn,
"applianceType": "0x%02X" % device_type,
Expand Down Expand Up @@ -391,7 +398,7 @@ def __init__(
f"{self._app_key}:{clouds['MSmartHome']['iot_key']}".encode("ascii")
).decode("ascii")

def _make_general_data(self):
def _make_general_data(self) -> dict[str, Any]:
return {
# "appVersion": self.APP_VERSION,
"src": self._app_id,
Expand All @@ -405,15 +412,17 @@ def _make_general_data(self):
"appId": self._app_id,
}

async def _api_request(self, endpoint: str, data: dict, header=None) -> dict | None:
async def _api_request(
self, endpoint: str, data: dict, header: dict | None = None
) -> dict[str, Any] | None:
header = header or {}
header.update(
{"x-recipe-app": self._app_id, "authorization": f"Basic {self._auth_base}"}
)

return await super()._api_request(endpoint, data, header)

async def _re_route(self):
async def _re_route(self) -> None:
data = self._make_general_data()
data.update({"userType": "0", "userName": f"{self._account}"})
if response := await self._api_request(
Expand Down Expand Up @@ -461,7 +470,9 @@ async def login(self) -> bool:
return True
return False

async def list_appliances(self, home_id) -> dict | None:
async def list_appliances(
self, home_id: str | None
) -> dict[int, dict[str, Any]] | None:
data = self._make_general_data()
if response := await self._api_request(
endpoint="/v1/appliance/user/list/get", data=data
Expand Down Expand Up @@ -502,7 +513,7 @@ async def download_lua(
sn: str,
model_number: str | None,
manufacturer_code: str = "0000",
):
) -> str | None:
data = {
"clientType": "1",
"appId": self._app_id,
Expand Down Expand Up @@ -554,9 +565,9 @@ def __init__(
password=password,
api_url=clouds[cloud_name]["api_url"],
)
self._session_id = None
self._session_id: str | None = None

def _make_general_data(self):
def _make_general_data(self) -> dict[str, Any]:
data = {
"src": self._app_id,
"format": "2",
Expand All @@ -570,7 +581,9 @@ def _make_general_data(self):
data.update({"sessionId": self._session_id})
return data

async def _api_request(self, endpoint: str, data: dict, header=None) -> dict | None:
async def _api_request(
self, endpoint: str, data: dict, header: dict[str, Any] | None = None
) -> dict[str, Any] | None:
header = header or {}
if not data.get("reqId"):
data.update({"reqId": token_hex(16)})
Expand Down Expand Up @@ -600,7 +613,7 @@ async def _api_request(self, endpoint: str, data: dict, header=None) -> dict | N
except Exception as e:
_LOGGER.warning(f"Midea cloud API error, url: {url}, error: {repr(e)}")
if int(response["errorCode"]) == 0 and "result" in response:
return response["result"]
return cast(dict[str, Any], response["result"])
return None

async def login(self) -> bool:
Expand All @@ -624,7 +637,9 @@ async def login(self) -> bool:
return True
return False

async def list_appliances(self, home_id) -> dict | None:
async def list_appliances(
self, home_id: str | None
) -> dict[int, dict[str, Any]] | None:
data = self._make_general_data()
if response := await self._api_request(
endpoint="/v1/appliance/user/list/get", data=data
Expand Down
15 changes: 8 additions & 7 deletions midealocal/packet_builder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import datetime
from typing import cast

from .security import LocalSecurity


class PacketBuilder:
def __init__(self, device_id: int, command):
self.command = None
def __init__(self, device_id: int, command: str) -> None:
self.command: str | None = None
self.security = LocalSecurity()
# aa20ac00000000000003418100ff03ff000200000000000000000000000006f274
# Init the packet with the header data.
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(self, device_id: int, command):
self.packet[20:28] = device_id.to_bytes(8, "little")
self.command = command

def finalize(self, msg_type=1):
def finalize(self, msg_type: int = 1) -> bytearray:
if msg_type != 1:
self.packet[3] = 0x10
self.packet[6] = 0x7B
Expand All @@ -77,15 +78,15 @@ def finalize(self, msg_type=1):
self.packet.extend(self.encode32(self.packet))
return self.packet

def encode32(self, data: bytearray):
def encode32(self, data: bytearray) -> bytes:
return self.security.encode32_data(data)

@staticmethod
def checksum(data):
return (~sum(data) + 1) & 0xFF
def checksum(data: bytes) -> bytes:
return cast(bytes, (~sum(data) + 1) & 0xFF)

@staticmethod
def packet_time():
def packet_time() -> bytearray:
t = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:16]
b = bytearray()
for i in range(0, len(t), 2):
Expand Down

0 comments on commit 69e9d94

Please sign in to comment.