Skip to content

Commit

Permalink
refactor: clean up shillelagh implementation
Browse files Browse the repository at this point in the history
initially i was implementing a sql interface within the dto objects, as
the solution grew it made more sense to be able to consolidate this into
the shillelagh extension, ensuring that if the sql interface isn't in use
the normal dto isn't burdened by imports

i've kept some helper methods to concat dictionaries of annotations, this
should not affect the workings of the solution

refs #31
  • Loading branch information
devraj committed May 29, 2024
1 parent 0dcc5e7 commit 2f45db0
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 128 deletions.
7 changes: 6 additions & 1 deletion gallagher/cc/alarms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from typing import Optional

from ..core import APIEndpoint, EndpointConfig, Capabilities
from ..core import (
APIEndpoint,
EndpointConfig,
Capabilities,
)

from ...dto.ref import (
AlarmRef,
Expand Down Expand Up @@ -60,6 +64,7 @@ async def get_config(cls) -> EndpointConfig:
endpoint=Capabilities.CURRENT.features.alarms.alarms,
dto_list=AlarmSummaryResponse,
dto_retrieve=AlarmDetail,
sql_model=AlarmSummary, # Temporary
)

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion gallagher/cc/cardholders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ..core import Capabilities, APIEndpoint, EndpointConfig

# TODO: remove
from ...dto.summary import CardholderSummary

from ...dto.detail import (
Expand All @@ -29,9 +30,9 @@ class Cardholder(APIEndpoint):
async def get_config(cls) -> EndpointConfig:
return EndpointConfig(
endpoint=Capabilities.CURRENT.features.cardholders.cardholders,
sql_model=CardholderSummary, # Temporary
dto_list=CardholderSummaryResponse,
dto_retrieve=CardholderDetail,
sql_model=CardholderSummary, # Temporary
)

@classmethod
Expand Down
69 changes: 17 additions & 52 deletions gallagher/dto/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ class AppBaseModel(BaseModel):
allow_extra=True,
)

@classmethod
def _accumulated_annotations(cls) -> dict:
"""Return a dictionary of all annotations
This method is used to return a dictionary of all the
annotations from itself and it's parent classes.
It is intended for use by the shillelagh extension
"""
annotations = cls.__annotations__.copy() # TODO: should we make copies?
for base in cls.__bases__:
if issubclass(base, BaseModel):
# Copy annotations from classes that are pydantic models
# they are the only things that form part of the response
annotations.update(base.__annotations__)
return annotations.items()

# Set to the last time each response was retrieved
# If it's set to None then the response was either created
# by the API client or it wasn't retrieved from the server
Expand All @@ -137,58 +154,6 @@ def model_post_init(self, __context) -> None:
"""
self._good_known_since = datetime.now()

@classmethod
def _shillelagh_columns(cls) -> dict:
"""Return the model as a __shillelagh__ compatible attribute config
Rules here are that we translate as many dictionary vars into
a __shillelagh__ compatible format.
If they are hrefs to other children then we select the id field for
each one of those objects
"""
from shillelagh.fields import (
Field,
Integer,
String,
Boolean,
Blob,
Collection,
Date,
DateTime,
Float,
ISODate,
ISODateTime,
IntBoolean,
StringBlob,
StringBoolean,
StringDate,
StringDateTime,
StringDecimal,
StringDuration,
StringInteger,
StringTime,
)

_map = {
int: Integer,
str: String,
bool: Boolean,
bytes: Blob,
list: Collection,
datetime: DateTime,
float: Float,
}

# Make a key, value pair of all the class attributes
# that are have a primitive type
table_fields = {
key: _map[value]()
for key, value in cls.__annotations__.items() # annotations not fields
if not key.startswith("_") and value in _map
}
return table_fields

def __repr__(self) -> str:
"""Return a string representation of the model
Expand Down
73 changes: 53 additions & 20 deletions gallagher/ext/shillelagh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import logging
import urllib
from datetime import datetime

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
# logging.basicConfig(level=logging.ERROR, format='%(asctime)s %(levelname)s %(message)s')
Expand Down Expand Up @@ -42,6 +43,9 @@
String,
Order,
Boolean,
DateTime,
Collection,
Blob,
)

# TODO: refactor this to generic based on SQL.md
Expand All @@ -66,6 +70,22 @@ class CCAPIAdapter(Adapter):
_all_tables = alarms_tables + cardholders_tables + \
status_overrides_tables

# Maps the python types to the shillelagh fields
# initialised here so we don't have to keep redefining it
_type_map = {
int: Integer,
str: String,
bool: Boolean,
bytes: Blob,
list: Collection,
datetime: DateTime,
float: Float,
}

# API endpoint being used to parse this URL
# TODO: see if this changes with every instantiation
_api_endpoint = None

# The adapter doesn't access the filesystem.
safe = True

Expand Down Expand Up @@ -167,15 +187,38 @@ def __init__(self, uri: str, api_key: Optional[str], **kwargs: Any):
# TODO: might be redundant due to moving this up the package level
cc.api_key = api_key

self._logger.debug(f"Finding suitable adapter for {self.uri}")

self._api_endpoint = next(
(table for table in self._all_tables if self.uri == f"{table.__config__.endpoint.href}"),
None
)

if self._api_endpoint:
self._logger.debug(f"Found helper class = {self._api_endpoint}")
else:
self._logger.debug("No suitable adapter found.")

def get_columns(self) -> Dict[str, Field]:
"""Return a pyndatic DTO as shil compatible attribute config
for table in self._all_tables:
self._logger.debug(f"Finding suitable adapter for {self.uri}")
if self.uri == f"{table.__config__.endpoint.href}":
self._logger.debug(f"Found helper class = {table}")
return table.__config__.sql_model._shillelagh_columns()
return {}
Rules here are that we translate as many dictionary vars into
a __shillelagh__ compatible format.
If they are hrefs to other children then we select the id field for
each one of those objects
"""

if not self._api_endpoint:
self._logger.debug("No suitable adapter found while get_columns.")
return {}

return {
key: self._type_map[value]() # shillelagh requires an instance
for key, value in \
self._api_endpoint.__config__.sql_model._accumulated_annotations()
if not key.startswith("_") and value in self._type_map
}


def get_data( # pylint: disable=too-many-locals
Expand All @@ -187,22 +230,12 @@ def get_data( # pylint: disable=too-many-locals
**kwargs: Any,
) -> Iterator[Row]:

cardholders = asyncio.run(Cardholder.list())
# cardholders = await Cardholder.list()
dto_list = asyncio.run(self._api_endpoint.list())

rindex = 0

for row in cardholders.results:
for row in dto_list.results:

yield {
"rowid": row.id,
"id": row.id,
"authorised": row.authorised,
"first_name": row.first_name,
"last_name": row.last_name,
'rowid': row.id, # Append this for shillelagh
**row.dict(), # Rest of the responses
}

rindex += 1
if rindex == limit:
break

56 changes: 2 additions & 54 deletions gallagher/ext/shillelagh/utils.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,2 @@



@classmethod
def _shillelagh_columns(cls) -> dict:
"""Return the model as a __shillelagh__ compatible attribute config
Rules here are that we translate as many dictionary vars into
a __shillelagh__ compatible format.
If they are hrefs to other children then we select the id field for
each one of those objects
"""
from shillelagh.fields import (
Field,
Integer,
String,
Boolean,
Blob,
Collection,
Date,
DateTime,
Float,
ISODate,
ISODateTime,
IntBoolean,
StringBlob,
StringBoolean,
StringDate,
StringDateTime,
StringDecimal,
StringDuration,
StringInteger,
StringTime,
)

_map = {
int: Integer,
str: String,
bool: Boolean,
bytes: Blob,
list: Collection,
datetime: DateTime,
float: Float,
}

# Make a key, value pair of all the class attributes
# that are have a primitive type
table_fields = {
key: _map[value]()
for key, value in cls.__annotations__.items() # annotations not fields
if not key.startswith("_") and value in _map
}
return table_fields
""" Utilities to help Shillelagh CLI commands
"""

0 comments on commit 2f45db0

Please sign in to comment.