diff --git a/pyflowlauncher/event.py b/pyflowlauncher/event.py index c156323..ebdf518 100644 --- a/pyflowlauncher/event.py +++ b/pyflowlauncher/event.py @@ -5,30 +5,36 @@ class EventHandler: def __init__(self): - self._methods = {} + self._events = {} self._handlers = {} def _get_callable_name(self, method: Callable[..., Any]): return getattr(method, '__name__', method.__class__.__name__).lower() - def add_method(self, method: Callable[..., Any], *, name=None) -> str: - key = name or self._get_callable_name(method) - self._methods[key] = method + def add_event(self, event: Callable[..., Any], *, name=None) -> str: + key = name or self._get_callable_name(event) + self._events[key] = event return key - def add_methods(self, methods: Iterable[Callable[..., Any]]): - for method in methods: - self.add_method(method) + def add_events(self, events: Iterable[Callable[..., Any]]): + for event in events: + self.add_event(event) def add_exception_handler(self, exception: Type[Exception], handler: Callable[..., Any]): self._handlers[exception] = handler - async def __call__(self, method: str, *args, **kwargs): + def _call_event(self, event: str, *args, **kwargs) -> Callable[..., Any] | None: + return self._events[event](*args, **kwargs) + + async def _await_maybe(self, result: Any) -> Any: + if asyncio.iscoroutine(result): + return await result + return result + + async def trigger_event(self, event: str, *args, **kwargs) -> Any: try: - result = self._methods[method](*args, **kwargs) - if asyncio.iscoroutine(result): - return await result - return result + result = self._call_event(event, *args, **kwargs) + return await self._await_maybe(result) except Exception as e: handler = self._handlers.get(type(e), None) if handler: diff --git a/pyflowlauncher/plugin.py b/pyflowlauncher/plugin.py index dca701e..d28f070 100644 --- a/pyflowlauncher/plugin.py +++ b/pyflowlauncher/plugin.py @@ -29,16 +29,16 @@ def __init__(self, methods: list[Method] | None = None) -> None: def add_method(self, method: Method) -> str: """Add a method to the event handler.""" - return self._event_handler.add_method(method) + return self._event_handler.add_event(method) def add_methods(self, methods: Iterable[Method]) -> None: - self._event_handler.add_methods(methods) + self._event_handler.add_events(methods) def on_method(self, method: Method) -> Method: @wraps(method) def wrapper(*args, **kwargs): return method(*args, **kwargs) - self._event_handler.add_method(wrapper) + self._event_handler.add_event(wrapper) return wrapper def method(self, method: Method) -> Method: diff --git a/tests/test_events.py b/tests/test_events.py index cf96923..e160df7 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -20,28 +20,28 @@ def except_method(): def test_add_method(): handler = EventHandler() - handler.add_method(temp_method1) - assert handler._methods == {"temp_method1": temp_method1} + handler.add_event(temp_method1) + assert handler._events == {"temp_method1": temp_method1} def test_add_methods(): handler = EventHandler() - handler.add_methods([temp_method1, temp_method2]) - assert handler._methods == {"temp_method1": temp_method1, "temp_method2": temp_method2} + handler.add_events([temp_method1, temp_method2]) + assert handler._events == {"temp_method1": temp_method1, "temp_method2": temp_method2} @pytest.mark.asyncio async def test_call(): handler = EventHandler() - handler.add_method(temp_method1) - assert await handler("temp_method1") is None + handler.add_event(temp_method1) + assert await handler.trigger_event("temp_method1") is None @pytest.mark.asyncio async def test_call_async(): handler = EventHandler() - handler.add_method(async_temp_method3) - assert await handler("async_temp_method3") is None + handler.add_event(async_temp_method3) + assert await handler.trigger_event("async_temp_method3") is None def test_add_exception_handler(): @@ -53,6 +53,6 @@ def test_add_exception_handler(): @pytest.mark.asyncio async def test_call_exception(): handler = EventHandler() - handler.add_method(except_method) + handler.add_event(except_method) with pytest.raises(Exception): - await handler("except_method") + await handler.trigger_event("except_method") diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 547da24..6dba757 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -17,13 +17,13 @@ def query(query: str): def test_add_method(): plugin = Plugin() plugin.add_method(temp_method1) - assert plugin._event_handler._methods == {'temp_method1': temp_method1} + assert plugin._event_handler._events == {'temp_method1': temp_method1} def test_add_methods(): plugin = Plugin() plugin.add_methods([temp_method1, temp_method2]) - assert plugin._event_handler._methods == {'temp_method1': temp_method1, 'temp_method2': temp_method2} + assert plugin._event_handler._events == {'temp_method1': temp_method1, 'temp_method2': temp_method2} def test_settings():