Skip to content

Commit

Permalink
Merge pull request #22 from codemation/improved-model-filtering-condi…
Browse files Browse the repository at this point in the history
…tions

Improved model filtering conditions
  • Loading branch information
codemation authored Dec 6, 2021
2 parents d33e840 + 477ba55 commit 8a9107c
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 50 deletions.
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ Note: At this point only the models have been created, but nothing is saved in t
```python
# get all hr managers currently employed
managers = await Employee.filter(
position=hr_manager,
is_employed=True
Employee.position==hr_manager, # conditional
is_employed=True # key-word argument
)

first_100_employees = await Employee.all(
limit=100
)

```
Expand Down Expand Up @@ -264,7 +268,4 @@ class Journey(DataBaseModel):
trip_id: str = PrimaryKey(default=get_uuid4)
waypoints: List[Optional[Coordinate]]




```
```
34 changes: 31 additions & 3 deletions docs/model-usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,35 @@ managers = await Employee.filter(
)
```
##### Filtering - Operators
`DataBaseModel`s are equipped with operator methods allow a more complete filtering of desired objects:
`DataBaseModel`s can be filtered using `>`, `>=`, `<=`, `<`, `==`, and a `.matches([value1, value2, value3])`

```python
# conditionals
mid_salary_employees = await Employees.filter(
Employees.salary >= 30000,
Employees.salary <= 40000
)

mid_salary_employees = await Employees.filter(
Employees.salary.matches([30000, 40000])
)

mid_salary_employees = await Employees.filter(
Employees.salary == 30000,
)

# combining conditionals with keyword args
mid_salary_employees = await Employees.filter(
Employees.OR(
Employees.salary >= 30000,
Employees.salary.matches([20000, 40000])
),
is_employed = True
)

```

`DataBaseModel`s are also equipped with operator methods allowing for additional filtering of desired objects

```python
# greater than or equal
Expand Down Expand Up @@ -186,7 +214,7 @@ Updates to `DataBaseModel` objects must be done directly via an object instance,


```python
all_employees = await Employees.select('*')
all_employees = await Employees.all()

# update is_employed to False for all employees

Expand All @@ -208,7 +236,7 @@ for employee in all_employees:
Much like updates, `DataBaseModel` objects can only be deleted by directly calling the `.delete()` method of an object instance.

```python
all_employees = await Employees.select('*')
all_employees = await Employees.all()

# delete latest employee
await all_employees[-1].delete()
Expand Down
225 changes: 187 additions & 38 deletions pydbantic/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from sqlalchemy.sql.functions import count
from pydantic import BaseModel, Field
import typing
from typing import Optional, Union, List
from typing import Iterable, Optional, Union, List, Any
import sqlalchemy
from sqlalchemy import select, func
from sqlalchemy import select, func, or_
from pickle import dumps, loads

class _Generic(BaseModel):
Expand Down Expand Up @@ -46,8 +45,90 @@ def Default(default=...):
if isinstance(default, type(lambda x: x)):
return Field(default_factory=default)
return Field(default=default)
class DataBaseModelCondition:
def __init__(
self,
description: str,
condition: sqlalchemy.sql.elements.BinaryExpression,
values
):
self.description = description
self.condition = condition
self.values = values
def __repr__(self):
return self.description
def str(self):
return self.description

class DataBaseModelAttribute:
def __init__(
self,
name: str,
column: sqlalchemy.sql.schema.Column,
table,
serialized: bool = False
):
self.name = name
self.column = column
self.table = table
self.serialized = serialized
def process_value(self, value):
if self.name in self.table['foreign_keys']:
foreign_table_name = value.__class__.__name__
primary_key = value.__class__.__metadata__.tables[foreign_table_name]['primary_key']
return getattr(value, primary_key)
if self.serialized:
return dumps(value)

return value

def __lt__(self, value) -> DataBaseModelCondition:
values = self.process_value(value)
return DataBaseModelCondition(
f"{self.name} < {values}",
self.column < self.process_value(value),
(values,)
)

def __le__(self, value) -> DataBaseModelCondition:
values = self.process_value(value)
return DataBaseModelCondition(
f"{self.name} <= {values}",
self.column <= self.process_value(value),
(values,)
)

def __gt__(self, value) -> DataBaseModelCondition:
values = self.process_value(value)
return DataBaseModelCondition(
f"{self.name} > {values}",
self.column > self.process_value(value),
(values,)
)

def __ge__(self, value) -> DataBaseModelCondition:
values = self.process_value(value)
return DataBaseModelCondition(
f"{self.name} >= {values}",
self.column >= self.process_value(value),
(values,)
)

def __eq__(self, value) -> DataBaseModelCondition:
values = self.process_value(value)
return DataBaseModelCondition(
f"{self.name} == {values}",
self.column == self.process_value(value),
(values,)
)

def matches(self, choices: List[Any]) -> DataBaseModelCondition:
choices = [self.process_value(value) for value in choices]
return DataBaseModelCondition(
f"{self.name} in {choices}",
self.column.in_(choices),
tuple(choices)
)

class DataBaseModel(BaseModel):
__metadata__: BaseMeta = BaseMeta()
Expand Down Expand Up @@ -114,6 +195,25 @@ def generate_sqlalchemy_table(cls):
cls.__metadata__.metadata,
*cls.convert_fields_to_columns()
)
cls.generate_model_attributes()
@classmethod
def generate_model_attributes(cls):
name = cls.__name__
for c, column in cls.__metadata__.tables[name]['column_map'].items():
sql_c = c
if c in cls.__metadata__.tables[name]['foreign_keys']:
sql_c = cls.__metadata__.tables[name]['foreign_keys'][c]
setattr(
cls,
c,
DataBaseModelAttribute(
c,
cls.__metadata__.tables[name]['table'].c[sql_c],
cls.__metadata__.tables[name],
column[2]
)
)


@classmethod
def convert_fields_to_columns(
Expand Down Expand Up @@ -294,33 +394,29 @@ async def save(self):
if not exists:
return await self.insert()
return await self.update()

@classmethod
def where(cls, query, where: dict, *conditions):
table = cls.get_table()
conditions = list(conditions)

values = []

for cond, value in where.items():
# check if cond is a foreign key, handle pulling foreign references matching query
if cond in cls.__metadata__.tables[cls.__name__]['foreign_keys']:
foreign_column_name = cls.__metadata__.tables[cls.__name__]['foreign_keys'][cond]
foreign_primary_key = cls.__metadata__.tables[value.__class__.__name__]['primary_key']
conditions.append(table.c[foreign_column_name]==getattr(value, foreign_primary_key))
continue
if cond not in table.c:
if not isinstance(cond, DataBaseModelAttribute) and hasattr(cls, cond):
cond = getattr(cls, cond)
else:
raise Exception(f"{cond} is not a valid column in {table}")

conditions.append(cond == value)
query_value = value

serialized = cls.__metadata__.tables[cls.__name__]['column_map'][cond][2]

if serialized:
if cond.serialized:
query_value = dumps(value)

conditions.append(table.c[cond] == query_value)
values.append(query_value)
values = []
for condition in conditions:
query = query.where(condition)
query = query.where(condition.condition)
if isinstance(condition.values, tuple):
values.extend(condition.values)

return query, tuple(values)

@classmethod
Expand All @@ -329,41 +425,88 @@ def get_table(cls):
cls.generate_sqlalchemy_table()

return cls.__metadata__.tables[cls.__name__]['table']


@classmethod
def OR(cls, *conditions, **filters) -> DataBaseModelCondition:
table = cls.get_table()
conditions = list(conditions)

for cond, value in filters.items():
if not isinstance(cond, DataBaseModelAttribute) and hasattr(cls, cond):
cond = getattr(cls, cond)
else:
raise Exception(f"{cond} is not a valid column in {table}")

conditions.append(cond == value)
values = []
for cond in conditions:
if isinstance(cond.values, tuple):
values.extend(cond.values)

return DataBaseModelCondition(
" OR ".join([str(cond) for cond in conditions]),
or_(*[cond.condition for cond in conditions]),
values=tuple(values)
)


@classmethod
def gt(cls, column, value):
def gt(cls, column, value) -> DataBaseModelCondition:
table = cls.get_table()
if not column in table.c:
raise Exception(f"{column} is not a valid column in {table}")
return table.c[column] > value

return DataBaseModelCondition(
f"{column} > {value}",
table.c[column] > value,
value
)

@classmethod
def gte(cls, column, value):
def gte(cls, column, value) -> DataBaseModelCondition:
table = cls.get_table()
if not column in table.c:
raise Exception(f"{column} is not a valid column in {table}")
return table.c[column] >= value
return DataBaseModelCondition(
f"{column} >= {value}",
table.c[column] >= value,
value
)


@classmethod
def lt(cls, column, value):
def lt(cls, column, value) -> DataBaseModelCondition:
table = cls.get_table()
if not column in table.c:
raise Exception(f"{column} is not a valid column in {table}")
return table.c[column] < value
return DataBaseModelCondition(
f"{column} < {value}",
table.c[column] < value,
value
)

@classmethod
def lte(cls, column, value):
def lte(cls, column, value) -> DataBaseModelCondition:
table = cls.get_table()
if not column in table.c:
raise Exception(f"{column} is not a valid column in {table}")
return table.c[column] <= value
return DataBaseModelCondition(
f"{column} <= {value}",
table.c[column] >= value,
value
)

@classmethod
def contains(cls, column, value):
def contains(cls, column, value) -> DataBaseModelCondition:
table = cls.get_table()
if not column in table.c:
raise Exception(f"{column} is not a valid column in {table}")
return table.c[column].contains(value)

return DataBaseModelCondition(
f"{value} in {column}",
table.c[column].contains(value),
value
)

@classmethod
def desc(cls, column):
Expand Down Expand Up @@ -568,7 +711,10 @@ async def filter(
if not order_by is None:
sel = sel.order_by(order_by)

results = await database.fetch(sel, cls.__name__, values)
results = await database.fetch(sel, cls.__name__, tuple(values))

normalized_results = cls.normalize(results)

rows = []
for result in cls.normalize(results):
values = {}
Expand Down Expand Up @@ -666,12 +812,15 @@ async def insert(self):
)

@classmethod
async def get(cls, **p_key):
for k in p_key:
primary_key = cls.__metadata__.tables[cls.__name__]['primary_key']
if k != cls.__metadata__.tables[cls.__name__]['primary_key']:
raise f"Expected primary key {primary_key}=<value>"
result = await cls.select('*', where={**p_key})
async def get(cls, *p_key_condition, **p_key):
if not p_key_condition:
for k in p_key:
primary_key = cls.__metadata__.tables[cls.__name__]['primary_key']
if k != cls.__metadata__.tables[cls.__name__]['primary_key']:
raise f"Expected primary key {primary_key}=<value>"
p_key_condition = [getattr(cls, primary_key) == p_key[k]]

result = await cls.filter(*p_key_condition)
return result[0] if result else None

@classmethod
Expand Down
Loading

0 comments on commit 8a9107c

Please sign in to comment.