Skip to content

Commit

Permalink
WIP async singleton with HassKey
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Dec 13, 2024
1 parent bf9788b commit c27237f
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 31 deletions.
21 changes: 12 additions & 9 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
language as language_util,
ulid as ulid_util,
)
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.limited_size_dict import LimitedSizeDict

from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
Expand Down Expand Up @@ -90,6 +91,8 @@
("tts_engine", "tts_language"),
)

KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)


def validate_language(data: dict[str, Any]) -> Any:
"""Validate language settings."""
Expand Down Expand Up @@ -247,7 +250,7 @@ async def async_create_default_pipeline(
The default pipeline will use the homeassistant conversation agent and the
specified stt / tts engines.
"""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = _async_resolve_default_pipeline_settings(
hass,
Expand Down Expand Up @@ -282,7 +285,7 @@ def _async_get_pipeline_from_conversation_entity(
@callback
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
"""Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]

if pipeline_id is None:
# A pipeline was not specified, use the preferred one
Expand All @@ -305,7 +308,7 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
@callback
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
"""Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]

return list(pipeline_data.pipeline_store.data.values())

Expand All @@ -328,7 +331,7 @@ async def async_update_pipeline(
prefer_local_intents: bool | UndefinedType = UNDEFINED,
) -> None:
"""Update a pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]

updates: dict[str, Any] = pipeline.to_json()
updates.pop("id")
Expand Down Expand Up @@ -586,7 +589,7 @@ def __post_init__(self) -> None:
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)

pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
if self.pipeline.id not in pipeline_data.pipeline_debug:
pipeline_data.pipeline_debug[self.pipeline.id] = LimitedSizeDict(
size_limit=STORED_PIPELINE_RUNS
Expand Down Expand Up @@ -614,7 +617,7 @@ def __eq__(self, other: object) -> bool:
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
self.event_callback(event)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
if self.id not in pipeline_data.pipeline_debug[self.pipeline.id]:
# This run has been evicted from the logged pipeline runs already
return
Expand Down Expand Up @@ -649,7 +652,7 @@ async def end(self) -> None:
)
)

pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
pipeline_data.pipeline_runs.remove_run(self)

async def prepare_wake_word_detection(self) -> None:
Expand Down Expand Up @@ -1207,7 +1210,7 @@ def _capture_chunk(self, audio_bytes: bytes | None) -> None:
return

# Forward to device audio capture
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
if audio_queue is None:
return
Expand Down Expand Up @@ -1858,7 +1861,7 @@ async def _async_migrate_func(
return old_data


@singleton(DOMAIN)
@singleton(KEY_ASSIST_PIPELINE, async_=True)
async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(
Expand Down
12 changes: 5 additions & 7 deletions homeassistant/components/assist_pipeline/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state

from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .const import OPTION_PREFERRED
from .pipeline import KEY_ASSIST_PIPELINE, AssistDevice
from .vad import VadSensitivity


Expand All @@ -30,7 +30,7 @@ def get_chosen_pipeline(
if state is None or state.state == OPTION_PREFERRED:
return None

pipeline_store: PipelineStorageCollection = hass.data[DOMAIN].pipeline_store
pipeline_store = hass.data[KEY_ASSIST_PIPELINE].pipeline_store
return next(
(item.id for item in pipeline_store.async_items() if item.name == state.state),
None,
Expand Down Expand Up @@ -80,7 +80,7 @@ async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()

pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
pipeline_store = pipeline_data.pipeline_store
self.async_on_remove(
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
Expand Down Expand Up @@ -116,9 +116,7 @@ async def _pipelines_updated(
@callback
def _update_options(self) -> None:
"""Handle pipeline update."""
pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
pipeline_store = self.hass.data[KEY_ASSIST_PIPELINE].pipeline_store
options = [OPTION_PREFERRED]
options.extend(sorted(item.name for item in pipeline_store.async_items()))
self._attr_options = options
Expand Down
11 changes: 5 additions & 6 deletions homeassistant/components/assist_pipeline/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@
from .const import (
DEFAULT_PIPELINE_TIMEOUT,
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
EVENT_RECORDING,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
)
from .error import PipelineNotFound
from .pipeline import (
KEY_ASSIST_PIPELINE,
AudioSettings,
DeviceAudioQueue,
PipelineData,
PipelineError,
PipelineEvent,
PipelineEventType,
Expand Down Expand Up @@ -284,7 +283,7 @@ def websocket_list_runs(
msg: dict[str, Any],
) -> None:
"""List pipeline runs for which debug data is available."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
pipeline_id = msg["pipeline_id"]

if pipeline_id not in pipeline_data.pipeline_debug:
Expand Down Expand Up @@ -320,7 +319,7 @@ def websocket_list_devices(
msg: dict[str, Any],
) -> None:
"""List assist devices."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
ent_reg = er.async_get(hass)
connection.send_result(
msg["id"],
Expand Down Expand Up @@ -351,7 +350,7 @@ def websocket_get_run(
msg: dict[str, Any],
) -> None:
"""Get debug data for a pipeline run."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
pipeline_id = msg["pipeline_id"]
pipeline_run_id = msg["pipeline_run_id"]

Expand Down Expand Up @@ -456,7 +455,7 @@ async def websocket_device_capture(
msg: dict[str, Any],
) -> None:
"""Capture raw audio from a satellite device and forward to client."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
device_id = msg["device_id"]

# Number of seconds to record audio in wall clock time
Expand Down
9 changes: 6 additions & 3 deletions homeassistant/components/esphome/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store
from homeassistant.util.hass_dict import HassKey

from .const import DOMAIN
from .coordinator import ESPHomeDashboardCoordinator

_LOGGER = logging.getLogger(__name__)


KEY_DASHBOARD_MANAGER = "esphome_dashboard_manager"
KEY_DASHBOARD_MANAGER: HassKey[ESPHomeDashboardManager] = HassKey(
"esphome_dashboard_manager"
)

STORAGE_KEY = "esphome.dashboard"
STORAGE_VERSION = 1
Expand All @@ -33,7 +36,7 @@ async def async_setup(hass: HomeAssistant) -> None:
await async_get_or_create_dashboard_manager(hass)


@singleton(KEY_DASHBOARD_MANAGER)
@singleton(KEY_DASHBOARD_MANAGER, async_=True)
async def async_get_or_create_dashboard_manager(
hass: HomeAssistant,
) -> ESPHomeDashboardManager:
Expand Down Expand Up @@ -140,7 +143,7 @@ def async_get_dashboard(hass: HomeAssistant) -> ESPHomeDashboardCoordinator | No
where manager can be an asyncio.Event instead of the actual manager
because the singleton decorator is not yet done.
"""
manager: ESPHomeDashboardManager | None = hass.data.get(KEY_DASHBOARD_MANAGER)
manager = hass.data.get(KEY_DASHBOARD_MANAGER)
return manager.async_get() if manager else None


Expand Down
27 changes: 21 additions & 6 deletions homeassistant/helpers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from collections.abc import Callable, Coroutine
import functools
from typing import Any, cast, overload
from typing import Any, Literal, cast, overload

from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey

type _FuncType[_T] = Callable[[HomeAssistant], _T]
type _Coro[_T] = Coroutine[Any, Any, _T]


@overload
def singleton[_T](
data_key: HassKey[_T], *, async_: Literal[True]
) -> Callable[[_FuncType[_Coro[_T]]], _FuncType[_Coro[_T]]]: ...


@overload
Expand All @@ -24,13 +31,21 @@ def singleton[_T](
def singleton[_T](data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...


def singleton[_T](data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
def singleton[_T](
data_key: Any, *, async_: bool = False
) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
"""Decorate a function that should be called once per instance.
Result will be cached and simultaneous calls will be handled.
"""

def wrapper(func: _FuncType[_T]) -> _FuncType[_T]:
@overload
def wrapper(func: _FuncType[_Coro[_T]]) -> _FuncType[_Coro[_T]]: ...

@overload
def wrapper(func: _FuncType[_T]) -> _FuncType[_T]: ...

def wrapper(func: _FuncType[_Coro[_T] | _T]) -> _FuncType[_Coro[_T] | _T]: # type: ignore[misc]
"""Wrap a function with caching logic."""
if not asyncio.iscoroutinefunction(func):

Expand All @@ -46,7 +61,7 @@ def wrapped(hass: HomeAssistant) -> _T:

@bind_hass
@functools.wraps(func)
async def async_wrapped(hass: HomeAssistant) -> Any:
async def async_wrapped(hass: HomeAssistant) -> _T:
if data_key not in hass.data:
evt = hass.data[data_key] = asyncio.Event()
result = await func(hass)
Expand All @@ -62,6 +77,6 @@ async def async_wrapped(hass: HomeAssistant) -> Any:

return cast(_T, obj_or_evt)

return async_wrapped # type: ignore[return-value]
return async_wrapped

return wrapper

0 comments on commit c27237f

Please sign in to comment.