From 9015fba0b129db04c4b7a4fd996fa9ccb36e48fb Mon Sep 17 00:00:00 2001 From: David Rapan Date: Tue, 3 Sep 2024 03:24:28 +0200 Subject: [PATCH] feat: Verify serial during init discovery --- custom_components/solarman/__init__.py | 2 +- custom_components/solarman/discovery.py | 102 ++++++++++++------------ 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/custom_components/solarman/__init__.py b/custom_components/solarman/__init__.py index a438738..1752361 100644 --- a/custom_components/solarman/__init__.py +++ b/custom_components/solarman/__init__.py @@ -44,7 +44,7 @@ async def async_setup_entry(hass: HomeAssistant, config: ConfigEntry) -> bool: ipaddr = IPv4Address(socket.gethostbyname(inverter_host)) if ipaddr.is_private: - inverter_discovery = InverterDiscovery(hass, inverter_host) + inverter_discovery = InverterDiscovery(hass, inverter_host, inverter_serial) if discovery: await inverter_discovery.discover() if (discovered_host := inverter_discovery.get_ip()): diff --git a/custom_components/solarman/discovery.py b/custom_components/solarman/discovery.py index a0f435b..26d835d 100644 --- a/custom_components/solarman/discovery.py +++ b/custom_components/solarman/discovery.py @@ -15,17 +15,15 @@ _LOGGER = logging.getLogger(__name__) class InverterDiscovery: - _port = DISCOVERY_PORT _message = DISCOVERY_MESSAGE.encode() - def __init__(self, hass: HomeAssistant, address = None): + def __init__(self, hass: HomeAssistant, ip = None, serial = None): self._hass = hass - self._address = address - self._ip = None - self._mac = None - self._serial = None + self._ip = ip + self._serial = serial + self._devices = {} - async def _discover(self, address = IP_BROADCAST, source = IP_ANY): + async def _discover(self, ip = IP_BROADCAST, wait = False, source = IP_ANY) -> dict: loop = asyncio.get_running_loop() try: @@ -38,23 +36,24 @@ async def _discover(self, address = IP_BROADCAST, source = IP_ANY): if source != IP_ANY: sock.bind((source, PORT_ANY)) - await loop.sock_sendto(sock, self._message, (address, self._port)) + await loop.sock_sendto(sock, self._message, (ip, DISCOVERY_PORT)) while True: try: recv = await loop.sock_recv(sock, DISCOVERY_RECV_MESSAGE_SIZE) data = recv.decode().split(',') if len(data) == 3: - self._ip = data[0] - self._mac = data[1] - self._serial = int(data[2]) - _LOGGER.debug(f"_discover: [{self._ip}, {self._mac}, {self._serial}]") + serial = int(data[2]) + yield serial, {"ip": data[0], "mac": data[1]} + _LOGGER.debug(f"_discover: [{data[0]}, {data[1]}, {serial}]") + if not wait: + return except (TimeoutError, socket.timeout): break except Exception as e: _LOGGER.exception(f"_discover: {format_exception(e)}") - async def _discover_all(self): + async def _discover_all(self) -> dict: _LOGGER.debug(f"_discover_all") adapters = await network.async_get_adapters(self._hass) @@ -67,67 +66,68 @@ async def _discover_all(self): _LOGGER.debug(f"_discover_all: Broadcasting on {net.with_prefixlen}") - await self._discover(str(IPv4Network(net, False).broadcast_address)) - #await self._discover(IP_BROADCAST, ipv4["address"]) - - if self._ip is not None: - return None + async for item in self._discover(str(IPv4Network(net, False).broadcast_address), True): + yield item async def discover(self): _LOGGER.debug(f"discover") - if self._address: - await self._discover(self._address) - - attempts_left = ACTION_ATTEMPTS - while self._ip is None and attempts_left > 0: - attempts_left -= 1 - - await self._discover_all() + devices = {} - if self._ip is None: - _LOGGER.debug(f"discover: {f'attempts left: {attempts_left}{'' if attempts_left > 0 else ', aborting.'}'}") - - async def discover_until_ok(self, x): - _LOGGER.debug(f"discover_until_ok") - - if self._address: - await self._discover(self._address) + if self._ip: + devices = {item[0]: item[1] async for item in self._discover(self._ip)} + if self._serial and self._serial != next(iter(devices)): + devices = {} attempts_left = ACTION_ATTEMPTS - while not self._ip and attempts_left > 0: + while len(devices) == 0 and attempts_left > 0: attempts_left -= 1 - await self._discover_all() + devices = {item[0]: item[1] async for item in self._discover_all()} - try: - await x(self._serial) - except: - self._ip = None - - if self._ip is None: + if len(devices) == 0: _LOGGER.debug(f"discover: {f'attempts left: {attempts_left}{'' if attempts_left > 0 else ', aborting.'}'}") + self._devices = devices + async def discover_ip(self): - if not self._ip: + if len(self._devices) == 0: await self.discover() - return self._ip + if len(self._devices) == 0: + return None + item = next(iter(self._devices)) + return self._devices[item]["ip"] async def discover_mac(self): - if not self._mac: + if len(self._devices) == 0: await self.discover() - return self._mac + if len(self._devices) == 0: + return None + item = next(iter(self._devices)) + return self._devices[item]["mac"] async def discover_serial(self): - if not self._serial: + if len(self._devices) == 0: await self.discover() - return self._serial + if len(self._devices) == 0: + return None + item = next(iter(self._devices)) + return item def get_ip(self): - return self._ip + if len(self._devices) == 0: + return None + item = next(iter(self._devices)) + return self._devices[item]["ip"] def get_mac(self): - return self._mac + if len(self._devices) == 0: + return None + item = next(iter(self._devices)) + return self._devices[item]["mac"] def get_serial(self): - return self._serial + if len(self._devices) == 0: + return None + item = next(iter(self._devices)) + return item \ No newline at end of file