Skip to content

Commit

Permalink
Migrate from peewee to sqlalchemy
Browse files Browse the repository at this point in the history
Signed-off-by: Eduard Kaverinskyi <[email protected]>
  • Loading branch information
EduKav1813 committed Nov 5, 2024
1 parent 7894fd2 commit 0f997eb
Show file tree
Hide file tree
Showing 25 changed files with 530 additions and 238 deletions.
5 changes: 2 additions & 3 deletions helpers/db_create.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import logging
from datetime import datetime
from whois.database import db, Device, User

from whois.data.db.database import Device, User, db

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("db_create")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bitfield.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import TestCase

from whois.types.bitfield import BitField
from whois.entity.bitfield import BitField


class BitFieldTest(TestCase):
Expand Down
7 changes: 5 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from whois.web import app
from unittest import TestCase

from whois.web import app


class ApiTestCase(TestCase):

Expand All @@ -11,4 +12,6 @@ def setUp(self):
def test_index(self):
"""User should be able to access the index page"""
response = self.app.get("/")
assert response.status_code == 200
assert (
response.status_code == 200
), f"Actual response code: {response.status_code}"
1 change: 0 additions & 1 deletion tests/test_mikrotik.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ def test_parse_duration():
for case, expected in data:
result = parse_duration(case)
assert result == expected

2 changes: 1 addition & 1 deletion whois/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from whois.web import app
import whois.settings as settings
from whois.web import app

app.run(host=settings.host)
3 changes: 3 additions & 0 deletions whois/data/db/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from sqlalchemy.orm import declarative_base

Base = declarative_base()
39 changes: 39 additions & 0 deletions whois/data/db/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from typing import Callable

import sqlalchemy as db

from whois.data.db.base import Base
from whois.data.table.device import DeviceTable
from whois.data.table.user import UserTable


class Database:
"""Represents the Database connection."""

def __init__(self, db_url: str = None):
if not db_url:
db_url = os.environ.get("APP_DB_URL", "sqlite:///whohacks.sqlite")
self.engine = db.create_engine(db_url)
self.metadata = db.MetaData()
self.connection = None

self.user_table = UserTable()
self.device_table = DeviceTable()
self.create_db()

@property
def is_connected(self) -> bool:
return self.connection is not None

def connect(self) -> None:
self.connection = self.engine.connect()

def disconnect(self) -> None:
if not self.connection:
raise RuntimeError("Cannot close database connection - already closed")
self.connection.close()

def create_db(self) -> None:
"""Ensure that the database exists with given schema."""
Base.metadata.create_all(self.engine)
22 changes: 22 additions & 0 deletions whois/data/db/mapper/device_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from whois.data.table.device import DeviceTable
from whois.entity.device import Device


def device_to_devicetable_mapper(device: Device) -> DeviceTable:
return DeviceTable(
mac_address=device.mac_address,
hostname=device.username,
last_seen=device.last_seen,
owner=device.owner,
flags=device.flags,
)


def devicetable_to_device_mapper(device: DeviceTable) -> Device:
return Device(
mac_address=device.mac_address,
hostname=device.username,
last_seen=device.last_seen,
owner=device.owner,
flags=device.flags,
)
22 changes: 22 additions & 0 deletions whois/data/db/mapper/user_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from whois.data.table.user import UserTable
from whois.entity.user import User


def user_to_usertable_mapper(user: User) -> UserTable:
return UserTable(
id=user.id,
username=user.username,
password=user._password,
display_name=user.display_name,
flags=user.flags,
)


def usertable_to_user_mapper(user: UserTable) -> User:
return User(
id=user.id,
username=user.username,
_password=user.password,
display_name=user.display_name,
flags=user.flags,
)
48 changes: 48 additions & 0 deletions whois/data/repository/device_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from datetime import datetime, timedelta, timezone
from typing import List

from sqlalchemy.orm import Session

from whois.data.db.database import Database
from whois.data.db.mapper.device_mapper import (device_to_devicetable_mapper,
devicetable_to_device_mapper)
from whois.data.table.device import DeviceTable
from whois.entity.device import Device


class DeviceRepository:

def __init__(self, database: Database) -> None:
self.database = database

def insert(self, device: Device) -> None:
with Session(self.database.engine) as session:
session.add(device_to_devicetable_mapper(device))
session.commit()

def update(self, device: Device) -> None:
with Session(self.database.engine) as session:
device_orm = (
session.query(DeviceTable)
.where(DeviceTable.mac_address == device.mac_address)
.one()
)
device_orm.hostname = device.hostname
device_orm.last_seen = device.last_seen
device_orm.owner = device.owner
device_orm.flags = device.flags
session.commit()

def get_by_mac_address(self, mac_address: str) -> Device:
with Session(self.database.engine) as session:
device_orm = (
session.query(DeviceTable)
.where(DeviceTable.mac_address == mac_address)
.one()
)
return map(devicetable_to_device_mapper, device_orm)

def get_all(self) -> List[Device]:
with Session(self.database.engine) as session:
devices_orm = session.query(DeviceTable).all()
return list(map(devicetable_to_device_mapper, devices_orm))
47 changes: 47 additions & 0 deletions whois/data/repository/user_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import List

from sqlalchemy.orm import Session

from whois.data.db.database import Database
from whois.data.db.mapper.user_mapper import (user_to_usertable_mapper,
usertable_to_user_mapper)
from whois.data.table.user import UserTable
from whois.entity.user import User


class UserRepository:
def __init__(self, database: Database) -> None:
self.database = database

def insert(self, user: User) -> None:
with Session(self.database.engine) as session:
session.add(user_to_usertable_mapper(user))
session.commit()

def update(self, user: User) -> None:
with Session(self.database.engine) as session:
user_orm = session.query(UserTable).where(UserTable.id == user.id).one()
user_orm.username = user.username
user_orm.password = user.password
user_orm.display_name = user.display_name
user_orm.flags = user.flags
session.commit()

def get_all(self) -> List[User]:
with Session(self.database.engine) as session:
users_orm = session.query(UserTable).all()
return list(map(usertable_to_user_mapper, users_orm))

def get_by_username(self, username: str) -> User:
with Session(self.database.engine) as session:
user_orm = (
session.query(UserTable).where(UserTable.username == username).one()
)

return usertable_to_user_mapper(user_orm)

def get_by_id(self, id: int) -> User:
with Session(self.database.engine) as session:
user_orm = session.query(UserTable).where(UserTable.id == id).one()

return usertable_to_user_mapper(user_orm)
29 changes: 29 additions & 0 deletions whois/data/table/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from sqlalchemy import Column, ForeignKey
from sqlalchemy.types import VARCHAR, Integer, String

from whois.data.db.base import Base
from whois.data.type.bitfield import BitField
from whois.data.type.iso_date_time_field import IsoDateTimeField


class DeviceTable(Base):
"""Represents the 'device' table in the database.
Columns:
mac_address: str(17) (Primary key)
hostname: str (Unique)
last_seen: IsoDateTimeField
owner: int (Foreign Key -> user.id)
flags: BitField (Nullable)
"""

__tablename__ = "device"

mac_address = Column(VARCHAR(17), primary_key=True, unique=True)
hostname = Column(String, nullable=True)
last_seen = IsoDateTimeField()
owner = Column(Integer, ForeignKey("user.id"), nullable=True, name="user_id")
flags = Column(BitField, nullable=True)

def __str__(self) -> str:
return str(self.mac_address)
24 changes: 24 additions & 0 deletions whois/data/table/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from sqlalchemy import Column, Integer, String

from whois.data.db.base import Base
from whois.data.type.bitfield import BitField


class UserTable(Base):
"""Represents the 'user' table in the database.
Columns:
id: int (Primary key)
username: str (Unique)
password: str (Nullable)
display_name: str
flags: BitField (Nullable)
"""

__tablename__ = "user"

id = Column(Integer, primary_key=True)
username = Column(String, unique=True)
password = Column(String, nullable=True)
display_name = Column(String)
flags = Column(BitField, nullable=True)
26 changes: 26 additions & 0 deletions whois/data/type/bitfield.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sqlalchemy.types as types
from sqlalchemy.ext.mutable import Mutable

from whois.entity.bitfield import BitField


class BitField(BitField, types.TypeDecorator, Mutable):
impl = types.Integer()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def process_bind_param(self, value, dialect):
"""Convert BitField to integer before storing in the database."""
if isinstance(value, BitField):
return int(value)
elif isinstance(value, int):
return value
return 0 # Default to 0 if None or invalid type

def process_result_value(self, value, dialect):
"""Convert integer from database back into a BitField instance."""
bitfield = BitField()
if value is not None:
bitfield._flags = value # directly set flags based on stored integer
return bitfield
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ def db_value(self, value: datetime) -> str:

def python_value(self, value: str) -> datetime:
if value:
return datetime.fromisoformat(value)
return datetime.fromisoformat(value)
Loading

0 comments on commit 0f997eb

Please sign in to comment.