diff --git a/gallagher/cc/alarms/__init__.py b/gallagher/cc/alarms/__init__.py index 6f781779..9f824827 100644 --- a/gallagher/cc/alarms/__init__.py +++ b/gallagher/cc/alarms/__init__.py @@ -6,7 +6,11 @@ from typing import Optional -from ..core import APIEndpoint, EndpointConfig, Capabilities +from ..core import ( + APIEndpoint, + EndpointConfig, + Capabilities, +) from ...dto.ref import ( AlarmRef, @@ -60,6 +64,7 @@ async def get_config(cls) -> EndpointConfig: endpoint=Capabilities.CURRENT.features.alarms.alarms, dto_list=AlarmSummaryResponse, dto_retrieve=AlarmDetail, + sql_model=AlarmSummary, # Temporary ) @classmethod diff --git a/gallagher/cc/cardholders/__init__.py b/gallagher/cc/cardholders/__init__.py index cb2894f2..77328003 100644 --- a/gallagher/cc/cardholders/__init__.py +++ b/gallagher/cc/cardholders/__init__.py @@ -4,6 +4,7 @@ from ..core import Capabilities, APIEndpoint, EndpointConfig +# TODO: remove from ...dto.summary import CardholderSummary from ...dto.detail import ( @@ -29,9 +30,9 @@ class Cardholder(APIEndpoint): async def get_config(cls) -> EndpointConfig: return EndpointConfig( endpoint=Capabilities.CURRENT.features.cardholders.cardholders, - sql_model=CardholderSummary, # Temporary dto_list=CardholderSummaryResponse, dto_retrieve=CardholderDetail, + sql_model=CardholderSummary, # Temporary ) @classmethod diff --git a/gallagher/dto/utils.py b/gallagher/dto/utils.py index 451da81c..7b0a385a 100644 --- a/gallagher/dto/utils.py +++ b/gallagher/dto/utils.py @@ -121,6 +121,23 @@ class AppBaseModel(BaseModel): allow_extra=True, ) + @classmethod + def _accumulated_annotations(cls) -> dict: + """Return a dictionary of all annotations + + This method is used to return a dictionary of all the + annotations from itself and it's parent classes. + + It is intended for use by the shillelagh extension + """ + annotations = cls.__annotations__.copy() # TODO: should we make copies? + for base in cls.__bases__: + if issubclass(base, BaseModel): + # Copy annotations from classes that are pydantic models + # they are the only things that form part of the response + annotations.update(base.__annotations__) + return annotations.items() + # Set to the last time each response was retrieved # If it's set to None then the response was either created # by the API client or it wasn't retrieved from the server @@ -137,58 +154,6 @@ def model_post_init(self, __context) -> None: """ self._good_known_since = datetime.now() - @classmethod - def _shillelagh_columns(cls) -> dict: - """Return the model as a __shillelagh__ compatible attribute config - - Rules here are that we translate as many dictionary vars into - a __shillelagh__ compatible format. - - If they are hrefs to other children then we select the id field for - each one of those objects - """ - from shillelagh.fields import ( - Field, - Integer, - String, - Boolean, - Blob, - Collection, - Date, - DateTime, - Float, - ISODate, - ISODateTime, - IntBoolean, - StringBlob, - StringBoolean, - StringDate, - StringDateTime, - StringDecimal, - StringDuration, - StringInteger, - StringTime, - ) - - _map = { - int: Integer, - str: String, - bool: Boolean, - bytes: Blob, - list: Collection, - datetime: DateTime, - float: Float, - } - - # Make a key, value pair of all the class attributes - # that are have a primitive type - table_fields = { - key: _map[value]() - for key, value in cls.__annotations__.items() # annotations not fields - if not key.startswith("_") and value in _map - } - return table_fields - def __repr__(self) -> str: """Return a string representation of the model diff --git a/gallagher/ext/shillelagh/__init__.py b/gallagher/ext/shillelagh/__init__.py index 1fb75065..3e69c96a 100644 --- a/gallagher/ext/shillelagh/__init__.py +++ b/gallagher/ext/shillelagh/__init__.py @@ -9,6 +9,7 @@ import asyncio import logging import urllib +from datetime import datetime logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') # logging.basicConfig(level=logging.ERROR, format='%(asctime)s %(levelname)s %(message)s') @@ -42,6 +43,9 @@ String, Order, Boolean, + DateTime, + Collection, + Blob, ) # TODO: refactor this to generic based on SQL.md @@ -66,6 +70,22 @@ class CCAPIAdapter(Adapter): _all_tables = alarms_tables + cardholders_tables + \ status_overrides_tables + # Maps the python types to the shillelagh fields + # initialised here so we don't have to keep redefining it + _type_map = { + int: Integer, + str: String, + bool: Boolean, + bytes: Blob, + list: Collection, + datetime: DateTime, + float: Float, + } + + # API endpoint being used to parse this URL + # TODO: see if this changes with every instantiation + _api_endpoint = None + # The adapter doesn't access the filesystem. safe = True @@ -167,15 +187,38 @@ def __init__(self, uri: str, api_key: Optional[str], **kwargs: Any): # TODO: might be redundant due to moving this up the package level cc.api_key = api_key + self._logger.debug(f"Finding suitable adapter for {self.uri}") + + self._api_endpoint = next( + (table for table in self._all_tables if self.uri == f"{table.__config__.endpoint.href}"), + None + ) + + if self._api_endpoint: + self._logger.debug(f"Found helper class = {self._api_endpoint}") + else: + self._logger.debug("No suitable adapter found.") def get_columns(self) -> Dict[str, Field]: + """Return a pyndatic DTO as shil compatible attribute config - for table in self._all_tables: - self._logger.debug(f"Finding suitable adapter for {self.uri}") - if self.uri == f"{table.__config__.endpoint.href}": - self._logger.debug(f"Found helper class = {table}") - return table.__config__.sql_model._shillelagh_columns() - return {} + Rules here are that we translate as many dictionary vars into + a __shillelagh__ compatible format. + + If they are hrefs to other children then we select the id field for + each one of those objects + """ + + if not self._api_endpoint: + self._logger.debug("No suitable adapter found while get_columns.") + return {} + + return { + key: self._type_map[value]() # shillelagh requires an instance + for key, value in \ + self._api_endpoint.__config__.sql_model._accumulated_annotations() + if not key.startswith("_") and value in self._type_map + } def get_data( # pylint: disable=too-many-locals @@ -187,22 +230,12 @@ def get_data( # pylint: disable=too-many-locals **kwargs: Any, ) -> Iterator[Row]: - cardholders = asyncio.run(Cardholder.list()) - # cardholders = await Cardholder.list() + dto_list = asyncio.run(self._api_endpoint.list()) - rindex = 0 - - for row in cardholders.results: + for row in dto_list.results: yield { - "rowid": row.id, - "id": row.id, - "authorised": row.authorised, - "first_name": row.first_name, - "last_name": row.last_name, + 'rowid': row.id, # Append this for shillelagh + **row.dict(), # Rest of the responses } - rindex += 1 - if rindex == limit: - break - diff --git a/gallagher/ext/shillelagh/utils.py b/gallagher/ext/shillelagh/utils.py index 1a70f12c..a8fb4b37 100644 --- a/gallagher/ext/shillelagh/utils.py +++ b/gallagher/ext/shillelagh/utils.py @@ -1,54 +1,2 @@ - - - -@classmethod -def _shillelagh_columns(cls) -> dict: - """Return the model as a __shillelagh__ compatible attribute config - - Rules here are that we translate as many dictionary vars into - a __shillelagh__ compatible format. - - If they are hrefs to other children then we select the id field for - each one of those objects - """ - from shillelagh.fields import ( - Field, - Integer, - String, - Boolean, - Blob, - Collection, - Date, - DateTime, - Float, - ISODate, - ISODateTime, - IntBoolean, - StringBlob, - StringBoolean, - StringDate, - StringDateTime, - StringDecimal, - StringDuration, - StringInteger, - StringTime, - ) - - _map = { - int: Integer, - str: String, - bool: Boolean, - bytes: Blob, - list: Collection, - datetime: DateTime, - float: Float, - } - - # Make a key, value pair of all the class attributes - # that are have a primitive type - table_fields = { - key: _map[value]() - for key, value in cls.__annotations__.items() # annotations not fields - if not key.startswith("_") and value in _map - } - return table_fields \ No newline at end of file +""" Utilities to help Shillelagh CLI commands +""" \ No newline at end of file