diff --git a/server/core/service.py b/server/core/service.py index ed2be8118..52bf6af43 100644 --- a/server/core/service.py +++ b/server/core/service.py @@ -1,5 +1,5 @@ import re -from typing import Dict, List +from typing import Any, Dict, List, Optional from .dependency_injector import DependencyInjector @@ -7,28 +7,22 @@ DependencyGraph = Dict[str, List[str]] -class ServiceMeta(type): - """ - For tracking which Services have been defined. - """ - - # Mapping from parameter name to class - services: Dict[str, type] = {} +service_registry: Dict[str, type] = {} - def __new__(cls, name, bases, attrs): - klass = type.__new__(cls, name, bases, attrs) - if name != "Service": - arg_name = snake_case(name) - cls.services[arg_name] = klass - return klass - -class Service(metaclass=ServiceMeta): +class Service(): """ All services should inherit from this class. Services are singleton objects which manage some server task. """ + def __init_subclass__(cls, name: Optional[str] = None, **kwargs: Any): + """ + For tracking which services have been defined. + """ + super().__init_subclass__(**kwargs) + arg_name = name or snake_case(cls.__name__) + service_registry[arg_name] = cls async def initialize(self) -> None: """ @@ -51,7 +45,7 @@ def create_services(injectables: Dict[str, object] = {}) -> Dict[str, Service]: injector = DependencyInjector() injector.add_injectables(**injectables) - return injector.build_classes(ServiceMeta.services) + return injector.build_classes(service_registry) def snake_case(string: str) -> str: diff --git a/tests/unit_tests/core/test_service.py b/tests/unit_tests/core/test_service.py new file mode 100644 index 000000000..1d63033bf --- /dev/null +++ b/tests/unit_tests/core/test_service.py @@ -0,0 +1,21 @@ +import mock + +from server.core import Service + + +def test_service_registry(): + with mock.patch("server.core.service.service_registry", {}) as registry: + class Foo(Service): + pass + + assert registry["foo"] is Foo + assert registry == {"foo": Foo} + + +def test_service_registry_name_override(): + with mock.patch("server.core.service.service_registry", {}) as registry: + class Foo(Service, name="FooService"): + pass + + assert registry["FooService"] is Foo + assert registry == {"FooService": Foo}