Skip to content

Commit

Permalink
chore: typing part 1 (#26)
Browse files Browse the repository at this point in the history
* chore: typing part 1

* chore: fix return value

* chore: import alias
  • Loading branch information
chemelli74 authored May 27, 2024
1 parent 1546db1 commit 1992fdd
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 53 deletions.
15 changes: 10 additions & 5 deletions midealocal/discover.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import socket
from typing import Any
import xml.etree.ElementTree as ET
from ipaddress import IPv4Network

Expand Down Expand Up @@ -148,7 +149,9 @@
)


def discover(discover_type=None, ip_address=None):
def discover(
discover_type: list | None = None, ip_address: list | None = None
) -> dict[int, dict[str, Any]]:
if discover_type is None:
discover_type = []
security = LocalSecurity()
Expand Down Expand Up @@ -200,6 +203,7 @@ def discover(discover_type=None, ip_address=None):
protocol = 1
root = ET.fromstring(data.decode(encoding="utf-8", errors="replace"))
child = root.find("body/device")
assert child
m = child.attrib
port, sn, device_type = (
int(m["port"]),
Expand Down Expand Up @@ -237,18 +241,19 @@ def discover(discover_type=None, ip_address=None):
return found_devices


def get_id_from_response(response):
def get_id_from_response(response: bytearray) -> int:
if response[64:-16][:6].hex() == "3c3f786d6c20":
xml = response[64:-16]
root = ET.fromstring(xml.decode(encoding="utf-8", errors="replace"))
child = root.find("smartDevice")
assert child
m = child.attrib
return int.from_bytes(bytearray.fromhex(m["devId"]), "little")
else:
return 0


def bytes2port(paramArrayOfbyte):
def bytes2port(paramArrayOfbyte: bytes | None) -> int:
if paramArrayOfbyte is None:
return 0
b, i = 0, 0
Expand All @@ -262,7 +267,7 @@ def bytes2port(paramArrayOfbyte):
return i


def get_device_info(device_ip, device_port: int):
def get_device_info(device_ip: str, device_port: int) -> bytearray:
response = bytearray(0)
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
Expand All @@ -284,7 +289,7 @@ def get_device_info(device_ip, device_port: int):
return response


def enum_all_broadcast():
def enum_all_broadcast() -> list:
nets = []
adapters = ifaddr.get_adapters()
for adapter in adapters:
Expand Down
128 changes: 80 additions & 48 deletions midealocal/security.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import hmac
from hashlib import md5, sha256
from typing import Any
from typing import Any, cast
from urllib.parse import unquote_plus, urlencode, urlparse

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from Crypto.Util.strxor import strxor
from Crypto.Cipher import AES

Buffer = bytes | bytearray | memoryview # alias from Crypto.Cipher.AES

MSGTYPE_HANDSHAKE_REQUEST = 0x0
MSGTYPE_HANDSHAKE_RESPONSE = 0x1
Expand All @@ -15,39 +17,48 @@


class CloudSecurity:
def __init__(self, login_key, iot_key, hmac_key, fixed_key=None, fixed_iv=None):
def __init__(
self,
login_key: str,
iot_key: str | None,
hmac_key: str | None,
fixed_key: int | None = None,
fixed_iv: int | None = None,
) -> None:
self._login_key = login_key
self._iot_key = iot_key
self._hmac_key = hmac_key
self._aes_key = None
self._aes_iv = None
self._aes_key: bytes
self._aes_iv: bytes
self._fixed_key = format(fixed_key, "x").encode("ascii") if fixed_key else None
self._fixed_iv = format(fixed_iv, "x").encode("ascii") if fixed_iv else None

def sign(self, url: str, data: Any, random: str) -> str:
msg = self._iot_key
msg += str(data)
def sign(self, url: str, data: str, random: str) -> str | None:
msg: str = self._iot_key or ""
msg += data
msg += random
if not self._hmac_key:
return None
sign = hmac.new(self._hmac_key.encode("ascii"), msg.encode("ascii"), sha256)
return sign.hexdigest()

def encrypt_password(self, login_id, data):
def encrypt_password(self, login_id: str, data: str) -> str:
m = sha256()
m.update(data.encode("ascii"))
login_hash = login_id + m.hexdigest() + self._login_key
m = sha256()
m.update(login_hash.encode("ascii"))
return m.hexdigest()

def encrypt_iam_password(self, login_id, data) -> str:
def encrypt_iam_password(self, login_id: str, data: str) -> str:
raise NotImplementedError

@staticmethod
def get_deviceid(username):
def get_deviceid(username: str) -> str:
return sha256(f"Hello, {username}!".encode("ascii")).digest().hex()[:16]

@staticmethod
def get_udp_id(appliance_id, method=0):
def get_udp_id(appliance_id: Any, method: int = 0) -> str | None:
if method == 0:
bytes_id = bytes(reversed(appliance_id.to_bytes(8, "big")))
elif method == 1:
Expand All @@ -61,21 +72,23 @@ def get_udp_id(appliance_id, method=0):
data[i] ^= data[i + 16]
return data[0:16].hex()

def set_aes_keys(self, key, iv):
def set_aes_keys(self, key: bytes | str, iv: bytes | str) -> None:
if isinstance(key, str):
key = key.encode("ascii")
if isinstance(iv, str):
iv = iv.encode("ascii")
self._aes_key = key
self._aes_iv = iv

def aes_encrypt_with_fixed_key(self, data):
def aes_encrypt_with_fixed_key(self, data: str) -> bytes:
return self.aes_encrypt(data, self._fixed_key, self._fixed_iv)

def aes_decrypt_with_fixed_key(self, data):
def aes_decrypt_with_fixed_key(self, data: str) -> str:
return self.aes_decrypt(data, self._fixed_key, self._fixed_iv)

def aes_encrypt(self, data, key=None, iv=None):
def aes_encrypt(
self, data: str | bytes, key: bytes | None = None, iv: bytes | None = None
) -> bytes:
if key is not None:
aes_key = key
aes_iv = iv
Expand All @@ -87,11 +100,15 @@ def aes_encrypt(self, data, key=None, iv=None):
if isinstance(data, str):
data = bytes.fromhex(data)
if aes_iv is None: # ECB
return AES.new(aes_key, AES.MODE_ECB).encrypt(pad(data, 16))
else: # CBC
return AES.new(aes_key, AES.MODE_CBC, iv=aes_iv).encrypt(pad(data, 16))
return cast(bytes, AES.new(aes_key, AES.MODE_ECB).encrypt(pad(data, 16)))
# CBC
return cast(
bytes, AES.new(aes_key, AES.MODE_CBC, iv=aes_iv).encrypt(pad(data, 16))
)

def aes_decrypt(self, data, key=None, iv=None):
def aes_decrypt(
self, data: str | bytes, key: bytes | None = None, iv: bytes | None = None
) -> str:
if key is not None:
aes_key = key
aes_iv = iv
Expand All @@ -103,20 +120,27 @@ def aes_decrypt(self, data, key=None, iv=None):
if isinstance(data, str):
data = bytes.fromhex(data)
if aes_iv is None: # ECB
return unpad(
AES.new(aes_key, AES.MODE_ECB).decrypt(data), len(aes_key)
).decode()
return cast(
str,
unpad(
AES.new(aes_key, AES.MODE_ECB).decrypt(data), len(aes_key)
).decode(),
)
else: # CBC
return unpad(
AES.new(aes_key, AES.MODE_CBC, iv=aes_iv).decrypt(data), len(aes_key)
).decode()
return cast(
str,
unpad(
AES.new(aes_key, AES.MODE_CBC, iv=aes_iv).decrypt(data),
len(aes_key),
).decode(),
)


class MeijuCloudSecurity(CloudSecurity):
def __init__(self, login_key, iot_key, hmac_key):
def __init__(self, login_key: str, iot_key: str, hmac_key: str) -> None:
super().__init__(login_key, iot_key, hmac_key, 10864842703515613082)

def encrypt_iam_password(self, login_id, data) -> str:
def encrypt_iam_password(self, login_id: str, data: str) -> str:
md = md5()
md.update(data.encode("ascii"))
md_second = md5()
Expand All @@ -125,12 +149,12 @@ def encrypt_iam_password(self, login_id, data) -> str:


class MSmartCloudSecurity(CloudSecurity):
def __init__(self, login_key, iot_key, hmac_key):
def __init__(self, login_key: str, iot_key: str, hmac_key: str) -> None:
super().__init__(
login_key, iot_key, hmac_key, 13101328926877700970, 16429062708050928556
)

def encrypt_iam_password(self, login_id, data) -> str:
def encrypt_iam_password(self, login_id: str, data: str) -> str:
md = md5()
md.update(data.encode("ascii"))
md_second = md5()
Expand All @@ -140,7 +164,9 @@ def encrypt_iam_password(self, login_id, data) -> str:
sha.update(login_hash.encode("ascii"))
return sha.hexdigest()

def set_aes_keys(self, encrypted_key, encrypted_iv):
def set_aes_keys(
self, encrypted_key: str | bytes, encrypted_iv: str | bytes
) -> None:
key_digest = sha256(self._login_key.encode("ascii")).hexdigest()
tmp_key = key_digest[:16].encode("ascii")
tmp_iv = key_digest[16:32].encode("ascii")
Expand All @@ -149,7 +175,7 @@ def set_aes_keys(self, encrypted_key, encrypted_iv):


class MideaAirSecurity(CloudSecurity):
def __init__(self, login_key):
def __init__(self, login_key: str) -> None:
super().__init__(login_key, None, None)

def sign(self, url: str, data: Any, random: str) -> str:
Expand All @@ -160,8 +186,11 @@ def sign(self, url: str, data: Any, random: str) -> str:


class LocalSecurity:
def __init__(self):
self.blockSize = 16
"""Evaluate local security."""

def __init__(self) -> None:
"""Initialize security."""
self.block_size = 16
self.iv = b"\0" * 16
self.aes_key = bytes.fromhex(
format(141661095494369103254425781617665632877, "x")
Expand All @@ -172,31 +201,34 @@ def __init__(self):
"x",
)
)
self._tcp_key = None
self._tcp_key: bytes
self._request_count = 0
self._response_count = 0

def aes_decrypt(self, raw):
def aes_decrypt(self, raw: bytes) -> bytes:
try:
return unpad(
AES.new(self.aes_key, AES.MODE_ECB).decrypt(bytearray(raw)), 16
return cast(
bytes,
unpad(AES.new(self.aes_key, AES.MODE_ECB).decrypt(bytearray(raw)), 16),
)
except ValueError:
return bytearray(0)

def aes_encrypt(self, raw):
return AES.new(self.aes_key, AES.MODE_ECB).encrypt(bytearray(pad(raw, 16)))
def aes_encrypt(self, raw: bytes) -> bytes:
return cast(
bytes, AES.new(self.aes_key, AES.MODE_ECB).encrypt(bytearray(pad(raw, 16)))
)

def aes_cbc_decrypt(self, raw, key):
return AES.new(key=key, mode=AES.MODE_CBC, iv=self.iv).decrypt(raw)
def aes_cbc_decrypt(self, raw: bytes, key: Buffer) -> bytes:
return cast(bytes, AES.new(key=key, mode=AES.MODE_CBC, iv=self.iv).decrypt(raw))

def aes_cbc_encrypt(self, raw, key):
return AES.new(key=key, mode=AES.MODE_CBC, iv=self.iv).encrypt(raw)
def aes_cbc_encrypt(self, raw: bytes, key: Buffer) -> bytes:
return cast(bytes, AES.new(key=key, mode=AES.MODE_CBC, iv=self.iv).encrypt(raw))

def encode32_data(self, raw):
def encode32_data(self, raw: bytes) -> bytes:
return md5(raw + self.salt).digest()

def tcp_key(self, response, key):
def tcp_key(self, response: bytes, key: Buffer) -> bytes:
if response == b"ERROR":
raise Exception("authentication failed")
if len(response) != 64:
Expand All @@ -211,7 +243,7 @@ def tcp_key(self, response, key):
self._response_count = 0
return self._tcp_key

def encode_8370(self, data, msgtype):
def encode_8370(self, data: bytes, msgtype: int) -> bytes:
header = bytearray([0x83, 0x70])
size, padding = len(data), 0
if msgtype in (MSGTYPE_ENCRYPTED_RESPONSE, MSGTYPE_ENCRYPTED_REQUEST):
Expand All @@ -230,7 +262,7 @@ def encode_8370(self, data, msgtype):
data = self.aes_cbc_encrypt(raw=data, key=self._tcp_key) + sign
return header + data

def decode_8370(self, data):
def decode_8370(self, data: bytes) -> tuple[list, bytes]:
if len(data) < 6:
return [], data
header = data[:6]
Expand Down

0 comments on commit 1992fdd

Please sign in to comment.