Skip to content

Commit

Permalink
Adjust the GitHub config flow (#105295)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludeeus authored Dec 27, 2023
1 parent 2d5176d commit b5012a9
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 32 deletions.
81 changes: 50 additions & 31 deletions homeassistant/components/github/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from __future__ import annotations

import asyncio
from typing import Any
from contextlib import suppress
from typing import TYPE_CHECKING, Any

from aiogithubapi import (
GitHubAPI,
Expand All @@ -17,7 +18,7 @@
from homeassistant import config_entries
from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.data_entry_flow import FlowResult, UnknownFlow
from homeassistant.helpers.aiohttp_client import (
SERVER_SOFTWARE,
async_get_clientsession,
Expand Down Expand Up @@ -118,19 +119,27 @@ async def async_step_device(
"""Handle device steps."""

async def _wait_for_login() -> None:
# mypy is not aware that we can't get here without having these set already
assert self._device is not None
assert self._login_device is not None
if TYPE_CHECKING:
# mypy is not aware that we can't get here without having these set already
assert self._device is not None
assert self._login_device is not None

try:
response = await self._device.activation(
device_code=self._login_device.device_code
)
self._login = response.data

finally:
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(flow_id=self.flow_id)
)

async def _progress():
# If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow
with suppress(UnknownFlow):
await self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id
)

self.hass.async_create_task(_progress())

if not self._device:
self._device = GitHubDeviceAPI(
Expand All @@ -139,40 +148,43 @@ async def _wait_for_login() -> None:
**{"client_name": SERVER_SOFTWARE},
)

try:
response = await self._device.register()
self._login_device = response.data
except GitHubException as exception:
LOGGER.exception(exception)
return self.async_abort(reason="could_not_register")
try:
response = await self._device.register()
self._login_device = response.data
except GitHubException as exception:
LOGGER.exception(exception)
return self.async_abort(reason="could_not_register")

if not self.login_task:
if self.login_task is None:
self.login_task = self.hass.async_create_task(_wait_for_login())
return self.async_show_progress(
step_id="device",
progress_action="wait_for_device",
description_placeholders={
"url": OAUTH_USER_LOGIN,
"code": self._login_device.user_code,
},
)

try:
await self.login_task
except GitHubException as exception:
LOGGER.exception(exception)
return self.async_show_progress_done(next_step_id="could_not_register")
if self.login_task.done():
if self.login_task.exception():
return self.async_show_progress_done(next_step_id="could_not_register")
return self.async_show_progress_done(next_step_id="repositories")

if TYPE_CHECKING:
# mypy is not aware that we can't get here without having this set already
assert self._login_device is not None

return self.async_show_progress_done(next_step_id="repositories")
return self.async_show_progress(
step_id="device",
progress_action="wait_for_device",
description_placeholders={
"url": OAUTH_USER_LOGIN,
"code": self._login_device.user_code,
},
)

async def async_step_repositories(
self,
user_input: dict[str, Any] | None = None,
) -> FlowResult:
"""Handle repositories step."""

# mypy is not aware that we can't get here without having this set already
assert self._login is not None
if TYPE_CHECKING:
# mypy is not aware that we can't get here without having this set already
assert self._login is not None

if not user_input:
repositories = await get_repositories(self.hass, self._login.access_token)
Expand Down Expand Up @@ -208,6 +220,13 @@ def async_get_options_flow(
"""Get the options flow for this handler."""
return OptionsFlowHandler(config_entry)

@callback
def async_remove(self) -> None:
"""Handle remove handler callback."""
if self.login_task and not self.login_task.done():
# Clean up login task if it's still running
self.login_task.cancel()


class OptionsFlowHandler(config_entries.OptionsFlow):
"""Handle a option flow for GitHub."""
Expand Down
41 changes: 40 additions & 1 deletion tests/components/github/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import AsyncMock, MagicMock, patch

from aiogithubapi import GitHubException
import pytest

from homeassistant import config_entries
from homeassistant.components.github.config_flow import get_repositories
Expand All @@ -12,7 +13,7 @@
)
from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.data_entry_flow import FlowResultType, UnknownFlow

from .common import MOCK_ACCESS_TOKEN

Expand Down Expand Up @@ -126,6 +127,44 @@ async def test_flow_with_activation_failure(
assert result["step_id"] == "could_not_register"


async def test_flow_with_remove_while_activating(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test flow with user canceling while activating."""
aioclient_mock.post(
"https://github.com/login/device/code",
json={
"device_code": "3584d83530557fdd1f46af8289938c8ef79f9dc5",
"user_code": "WDJB-MJHT",
"verification_uri": "https://github.com/login/device",
"expires_in": 900,
"interval": 5,
},
headers={"Content-Type": "application/json"},
)
aioclient_mock.post(
"https://github.com/login/oauth/access_token",
json={"error": "authorization_pending"},
headers={"Content-Type": "application/json"},
)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_USER},
)
assert result["step_id"] == "device"
assert result["type"] == FlowResultType.SHOW_PROGRESS

assert hass.config_entries.flow.async_get(result["flow_id"])

# Simulate user canceling the flow
hass.config_entries.flow._async_remove_flow_progress(result["flow_id"])
await hass.async_block_till_done()

with pytest.raises(UnknownFlow):
hass.config_entries.flow.async_get(result["flow_id"])


async def test_already_configured(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
Expand Down

0 comments on commit b5012a9

Please sign in to comment.