diff --git a/homeassistant/components/github/config_flow.py b/homeassistant/components/github/config_flow.py index 5e223483e2e9b..c90caf0fc894c 100644 --- a/homeassistant/components/github/config_flow.py +++ b/homeassistant/components/github/config_flow.py @@ -2,7 +2,8 @@ from __future__ import annotations import asyncio -from typing import Any +from contextlib import suppress +from typing import TYPE_CHECKING, Any from aiogithubapi import ( GitHubAPI, @@ -17,7 +18,7 @@ from homeassistant import config_entries from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult +from homeassistant.data_entry_flow import FlowResult, UnknownFlow from homeassistant.helpers.aiohttp_client import ( SERVER_SOFTWARE, async_get_clientsession, @@ -118,19 +119,27 @@ async def async_step_device( """Handle device steps.""" async def _wait_for_login() -> None: - # mypy is not aware that we can't get here without having these set already - assert self._device is not None - assert self._login_device is not None + if TYPE_CHECKING: + # mypy is not aware that we can't get here without having these set already + assert self._device is not None + assert self._login_device is not None try: response = await self._device.activation( device_code=self._login_device.device_code ) self._login = response.data + finally: - self.hass.async_create_task( - self.hass.config_entries.flow.async_configure(flow_id=self.flow_id) - ) + + async def _progress(): + # If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow + with suppress(UnknownFlow): + await self.hass.config_entries.flow.async_configure( + flow_id=self.flow_id + ) + + self.hass.async_create_task(_progress()) if not self._device: self._device = GitHubDeviceAPI( @@ -139,31 +148,33 @@ async def _wait_for_login() -> None: **{"client_name": SERVER_SOFTWARE}, ) - try: - response = await self._device.register() - self._login_device = response.data - except GitHubException as exception: - LOGGER.exception(exception) - return self.async_abort(reason="could_not_register") + try: + response = await self._device.register() + self._login_device = response.data + except GitHubException as exception: + LOGGER.exception(exception) + return self.async_abort(reason="could_not_register") - if not self.login_task: + if self.login_task is None: self.login_task = self.hass.async_create_task(_wait_for_login()) - return self.async_show_progress( - step_id="device", - progress_action="wait_for_device", - description_placeholders={ - "url": OAUTH_USER_LOGIN, - "code": self._login_device.user_code, - }, - ) - try: - await self.login_task - except GitHubException as exception: - LOGGER.exception(exception) - return self.async_show_progress_done(next_step_id="could_not_register") + if self.login_task.done(): + if self.login_task.exception(): + return self.async_show_progress_done(next_step_id="could_not_register") + return self.async_show_progress_done(next_step_id="repositories") + + if TYPE_CHECKING: + # mypy is not aware that we can't get here without having this set already + assert self._login_device is not None - return self.async_show_progress_done(next_step_id="repositories") + return self.async_show_progress( + step_id="device", + progress_action="wait_for_device", + description_placeholders={ + "url": OAUTH_USER_LOGIN, + "code": self._login_device.user_code, + }, + ) async def async_step_repositories( self, @@ -171,8 +182,9 @@ async def async_step_repositories( ) -> FlowResult: """Handle repositories step.""" - # mypy is not aware that we can't get here without having this set already - assert self._login is not None + if TYPE_CHECKING: + # mypy is not aware that we can't get here without having this set already + assert self._login is not None if not user_input: repositories = await get_repositories(self.hass, self._login.access_token) @@ -208,6 +220,13 @@ def async_get_options_flow( """Get the options flow for this handler.""" return OptionsFlowHandler(config_entry) + @callback + def async_remove(self) -> None: + """Handle remove handler callback.""" + if self.login_task and not self.login_task.done(): + # Clean up login task if it's still running + self.login_task.cancel() + class OptionsFlowHandler(config_entries.OptionsFlow): """Handle a option flow for GitHub.""" diff --git a/tests/components/github/test_config_flow.py b/tests/components/github/test_config_flow.py index a86e1d134aa11..8d61eca1ab1bc 100644 --- a/tests/components/github/test_config_flow.py +++ b/tests/components/github/test_config_flow.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from aiogithubapi import GitHubException +import pytest from homeassistant import config_entries from homeassistant.components.github.config_flow import get_repositories @@ -12,7 +13,7 @@ ) from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import FlowResultType +from homeassistant.data_entry_flow import FlowResultType, UnknownFlow from .common import MOCK_ACCESS_TOKEN @@ -126,6 +127,44 @@ async def test_flow_with_activation_failure( assert result["step_id"] == "could_not_register" +async def test_flow_with_remove_while_activating( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, +) -> None: + """Test flow with user canceling while activating.""" + aioclient_mock.post( + "https://github.com/login/device/code", + json={ + "device_code": "3584d83530557fdd1f46af8289938c8ef79f9dc5", + "user_code": "WDJB-MJHT", + "verification_uri": "https://github.com/login/device", + "expires_in": 900, + "interval": 5, + }, + headers={"Content-Type": "application/json"}, + ) + aioclient_mock.post( + "https://github.com/login/oauth/access_token", + json={"error": "authorization_pending"}, + headers={"Content-Type": "application/json"}, + ) + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_USER}, + ) + assert result["step_id"] == "device" + assert result["type"] == FlowResultType.SHOW_PROGRESS + + assert hass.config_entries.flow.async_get(result["flow_id"]) + + # Simulate user canceling the flow + hass.config_entries.flow._async_remove_flow_progress(result["flow_id"]) + await hass.async_block_till_done() + + with pytest.raises(UnknownFlow): + hass.config_entries.flow.async_get(result["flow_id"]) + + async def test_already_configured( hass: HomeAssistant, mock_config_entry: MockConfigEntry,