From a017e7af46130266226f6c9a100b05703dd6cb5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Mind=C3=AAllo=20de=20Andrade?= Date: Thu, 1 Aug 2024 21:17:54 -0300 Subject: [PATCH] fix(cli): authenticate to get keys (#256) ## Summary by CodeRabbit - **New Features** - Enhanced authentication logic for cloud key retrieval to ensure valid session checks before accessing device-specific keys. - **Bug Fixes** - Improved error handling to prevent attempts to access cloud keys when authentication fails, reducing potential API misuse. - **Tests** - Added an asynchronous test for `_get_keys` to validate key retrieval and caching behavior under various scenarios, improving overall test coverage. --- midealocal/cli.py | 8 +++++++- tests/cli_test.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/midealocal/cli.py b/midealocal/cli.py index b834584b..76f9c093 100644 --- a/midealocal/cli.py +++ b/midealocal/cli.py @@ -55,8 +55,14 @@ async def _get_cloud(self) -> MideaCloud: async def _get_keys(self, device_id: int) -> dict[int, dict[str, Any]]: cloud = await self._get_cloud() - cloud_keys = await cloud.get_cloud_keys(device_id) default_keys = await cloud.get_default_keys() + if not await cloud.login(): + _LOGGER.warning( + "Failed to authenticate to the cloud. Using only default keys.", + ) + return default_keys + cloud_keys = await cloud.get_cloud_keys(device_id) + return {**cloud_keys, **default_keys} async def discover(self) -> MideaDevice | None: diff --git a/tests/cli_test.py b/tests/cli_test.py index 0f87045c..63268b09 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -60,6 +60,46 @@ async def test_get_cloud(self) -> None: with pytest.raises(ElementMissing): await self.cli._get_cloud() + async def test_get_keys(self) -> None: + """Test get keys.""" + mock_cloud = AsyncMock() + with ( + patch("midealocal.cli.get_midea_cloud", return_value=mock_cloud), + patch.object( + mock_cloud, + "get_default_keys", + return_value={99: {"key": "key99", "token": "token99"}}, + ) as mock_default_keys, + patch.object( + mock_cloud, + "get_cloud_keys", + return_value={ + 0: {"key": "key0", "token": "token0"}, + 1: {"key": "key1", "token": "token1"}, + }, + ) as mock_cloud_keys, + patch.object(mock_cloud, "login", side_effect=[True, False]), + ): + keys = await self.cli._get_keys(0) + assert len(keys) == 3 + assert keys[0]["key"] == "key0" + assert keys[1]["key"] == "key1" + assert keys[99]["key"] == "key99" + assert keys[0]["token"] == "token0" + assert keys[1]["token"] == "token1" + assert keys[99]["token"] == "token99" + mock_default_keys.assert_called_once() + mock_default_keys.reset_mock() + mock_cloud_keys.assert_called_once_with(0) + mock_cloud_keys.reset_mock() + + keys = await self.cli._get_keys(0) + assert len(keys) == 1 + assert keys[99]["key"] == "key99" + assert keys[99]["token"] == "token99" + mock_default_keys.assert_called_once() + mock_cloud_keys.assert_not_called() + async def test_discover(self) -> None: """Test discover.""" mock_device = {