diff --git a/helpers/db_create.py b/helpers/db_create.py index a876c2e..f0ada84 100644 --- a/helpers/db_create.py +++ b/helpers/db_create.py @@ -1,7 +1,6 @@ -import os import logging -from datetime import datetime -from whois.database import db, Device, User + +from whois.data.db.database import Device, User, db logging.basicConfig(level=logging.INFO) logger = logging.getLogger("db_create") diff --git a/tests/test_bitfield.py b/tests/test_bitfield.py index ac6dc43..ae1840a 100644 --- a/tests/test_bitfield.py +++ b/tests/test_bitfield.py @@ -1,6 +1,6 @@ from unittest import TestCase -from whois.types.bitfield import BitField +from whois.entity.bitfield import BitField class BitFieldTest(TestCase): diff --git a/tests/test_integration.py b/tests/test_integration.py index d582578..afa2556 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,6 +1,7 @@ -from whois.web import app from unittest import TestCase +from whois.web import app + class ApiTestCase(TestCase): @@ -11,4 +12,6 @@ def setUp(self): def test_index(self): """User should be able to access the index page""" response = self.app.get("/") - assert response.status_code == 200 + assert ( + response.status_code == 200 + ), f"Actual response code: {response.status_code}" diff --git a/tests/test_mikrotik.py b/tests/test_mikrotik.py index d1d63bf..aeb5418 100644 --- a/tests/test_mikrotik.py +++ b/tests/test_mikrotik.py @@ -19,4 +19,3 @@ def test_parse_duration(): for case, expected in data: result = parse_duration(case) assert result == expected - diff --git a/whois/__main__.py b/whois/__main__.py index 5af7b3c..64eff59 100644 --- a/whois/__main__.py +++ b/whois/__main__.py @@ -1,4 +1,4 @@ -from whois.web import app import whois.settings as settings +from whois.web import app app.run(host=settings.host) diff --git a/whois/data/db/base.py b/whois/data/db/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/whois/data/db/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/whois/data/db/database.py b/whois/data/db/database.py new file mode 100644 index 0000000..8cb56b2 --- /dev/null +++ b/whois/data/db/database.py @@ -0,0 +1,39 @@ +import os +from typing import Callable + +import sqlalchemy as db + +from whois.data.db.base import Base +from whois.data.table.device import DeviceTable +from whois.data.table.user import UserTable + + +class Database: + """Represents the Database connection.""" + + def __init__(self, db_url: str = None): + if not db_url: + db_url = os.environ.get("APP_DB_URL", "sqlite:///whohacks.sqlite") + self.engine = db.create_engine(db_url) + self.metadata = db.MetaData() + self.connection = None + + self.user_table = UserTable() + self.device_table = DeviceTable() + self.create_db() + + @property + def is_connected(self) -> bool: + return self.connection is not None + + def connect(self) -> None: + self.connection = self.engine.connect() + + def disconnect(self) -> None: + if not self.connection: + raise RuntimeError("Cannot close database connection - already closed") + self.connection.close() + + def create_db(self) -> None: + """Ensure that the database exists with given schema.""" + Base.metadata.create_all(self.engine) diff --git a/whois/data/db/mapper/device_mapper.py b/whois/data/db/mapper/device_mapper.py new file mode 100644 index 0000000..49e399b --- /dev/null +++ b/whois/data/db/mapper/device_mapper.py @@ -0,0 +1,22 @@ +from whois.data.table.device import DeviceTable +from whois.entity.device import Device + + +def device_to_devicetable_mapper(device: Device) -> DeviceTable: + return DeviceTable( + mac_address=device.mac_address, + hostname=device.username, + last_seen=device.last_seen, + owner=device.owner, + flags=device.flags, + ) + + +def devicetable_to_device_mapper(device: DeviceTable) -> Device: + return Device( + mac_address=device.mac_address, + hostname=device.username, + last_seen=device.last_seen, + owner=device.owner, + flags=device.flags, + ) diff --git a/whois/data/db/mapper/user_mapper.py b/whois/data/db/mapper/user_mapper.py new file mode 100644 index 0000000..463f162 --- /dev/null +++ b/whois/data/db/mapper/user_mapper.py @@ -0,0 +1,22 @@ +from whois.data.table.user import UserTable +from whois.entity.user import User + + +def user_to_usertable_mapper(user: User) -> UserTable: + return UserTable( + id=user.id, + username=user.username, + password=user._password, + display_name=user.display_name, + flags=user.flags, + ) + + +def usertable_to_user_mapper(user: UserTable) -> User: + return User( + id=user.id, + username=user.username, + _password=user.password, + display_name=user.display_name, + flags=user.flags, + ) diff --git a/whois/data/repository/device_repository.py b/whois/data/repository/device_repository.py new file mode 100644 index 0000000..d37fef4 --- /dev/null +++ b/whois/data/repository/device_repository.py @@ -0,0 +1,48 @@ +from datetime import datetime, timedelta, timezone +from typing import List + +from sqlalchemy.orm import Session + +from whois.data.db.database import Database +from whois.data.db.mapper.device_mapper import (device_to_devicetable_mapper, + devicetable_to_device_mapper) +from whois.data.table.device import DeviceTable +from whois.entity.device import Device + + +class DeviceRepository: + + def __init__(self, database: Database) -> None: + self.database = database + + def insert(self, device: Device) -> None: + with Session(self.database.engine) as session: + session.add(device_to_devicetable_mapper(device)) + session.commit() + + def update(self, device: Device) -> None: + with Session(self.database.engine) as session: + device_orm = ( + session.query(DeviceTable) + .where(DeviceTable.mac_address == device.mac_address) + .one() + ) + device_orm.hostname = device.hostname + device_orm.last_seen = device.last_seen + device_orm.owner = device.owner + device_orm.flags = device.flags + session.commit() + + def get_by_mac_address(self, mac_address: str) -> Device: + with Session(self.database.engine) as session: + device_orm = ( + session.query(DeviceTable) + .where(DeviceTable.mac_address == mac_address) + .one() + ) + return map(devicetable_to_device_mapper, device_orm) + + def get_all(self) -> List[Device]: + with Session(self.database.engine) as session: + devices_orm = session.query(DeviceTable).all() + return list(map(devicetable_to_device_mapper, devices_orm)) diff --git a/whois/data/repository/user_repository.py b/whois/data/repository/user_repository.py new file mode 100644 index 0000000..cdefb13 --- /dev/null +++ b/whois/data/repository/user_repository.py @@ -0,0 +1,47 @@ +from typing import List + +from sqlalchemy.orm import Session + +from whois.data.db.database import Database +from whois.data.db.mapper.user_mapper import (user_to_usertable_mapper, + usertable_to_user_mapper) +from whois.data.table.user import UserTable +from whois.entity.user import User + + +class UserRepository: + def __init__(self, database: Database) -> None: + self.database = database + + def insert(self, user: User) -> None: + with Session(self.database.engine) as session: + session.add(user_to_usertable_mapper(user)) + session.commit() + + def update(self, user: User) -> None: + with Session(self.database.engine) as session: + user_orm = session.query(UserTable).where(UserTable.id == user.id).one() + user_orm.username = user.username + user_orm.password = user.password + user_orm.display_name = user.display_name + user_orm.flags = user.flags + session.commit() + + def get_all(self) -> List[User]: + with Session(self.database.engine) as session: + users_orm = session.query(UserTable).all() + return list(map(usertable_to_user_mapper, users_orm)) + + def get_by_username(self, username: str) -> User: + with Session(self.database.engine) as session: + user_orm = ( + session.query(UserTable).where(UserTable.username == username).one() + ) + + return usertable_to_user_mapper(user_orm) + + def get_by_id(self, id: int) -> User: + with Session(self.database.engine) as session: + user_orm = session.query(UserTable).where(UserTable.id == id).one() + + return usertable_to_user_mapper(user_orm) diff --git a/whois/data/table/device.py b/whois/data/table/device.py new file mode 100644 index 0000000..f8815f9 --- /dev/null +++ b/whois/data/table/device.py @@ -0,0 +1,29 @@ +from sqlalchemy import Column, ForeignKey +from sqlalchemy.types import VARCHAR, Integer, String + +from whois.data.db.base import Base +from whois.data.type.bitfield import BitField +from whois.data.type.iso_date_time_field import IsoDateTimeField + + +class DeviceTable(Base): + """Represents the 'device' table in the database. + + Columns: + mac_address: str(17) (Primary key) + hostname: str (Unique) + last_seen: IsoDateTimeField + owner: int (Foreign Key -> user.id) + flags: BitField (Nullable) + """ + + __tablename__ = "device" + + mac_address = Column(VARCHAR(17), primary_key=True, unique=True) + hostname = Column(String, nullable=True) + last_seen = IsoDateTimeField() + owner = Column(Integer, ForeignKey("user.id"), nullable=True, name="user_id") + flags = Column(BitField, nullable=True) + + def __str__(self) -> str: + return str(self.mac_address) diff --git a/whois/data/table/user.py b/whois/data/table/user.py new file mode 100644 index 0000000..eca5dd0 --- /dev/null +++ b/whois/data/table/user.py @@ -0,0 +1,24 @@ +from sqlalchemy import Column, Integer, String + +from whois.data.db.base import Base +from whois.data.type.bitfield import BitField + + +class UserTable(Base): + """Represents the 'user' table in the database. + + Columns: + id: int (Primary key) + username: str (Unique) + password: str (Nullable) + display_name: str + flags: BitField (Nullable) + """ + + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + username = Column(String, unique=True) + password = Column(String, nullable=True) + display_name = Column(String) + flags = Column(BitField, nullable=True) diff --git a/whois/data/type/bitfield.py b/whois/data/type/bitfield.py new file mode 100644 index 0000000..2745f52 --- /dev/null +++ b/whois/data/type/bitfield.py @@ -0,0 +1,26 @@ +import sqlalchemy.types as types +from sqlalchemy.ext.mutable import Mutable + +from whois.entity.bitfield import BitField + + +class BitField(BitField, types.TypeDecorator, Mutable): + impl = types.Integer() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def process_bind_param(self, value, dialect): + """Convert BitField to integer before storing in the database.""" + if isinstance(value, BitField): + return int(value) + elif isinstance(value, int): + return value + return 0 # Default to 0 if None or invalid type + + def process_result_value(self, value, dialect): + """Convert integer from database back into a BitField instance.""" + bitfield = BitField() + if value is not None: + bitfield._flags = value # directly set flags based on stored integer + return bitfield diff --git a/whois/types/iso_date_time_field.py b/whois/data/type/iso_date_time_field.py similarity index 86% rename from whois/types/iso_date_time_field.py rename to whois/data/type/iso_date_time_field.py index d95dc5f..4bc64eb 100644 --- a/whois/types/iso_date_time_field.py +++ b/whois/data/type/iso_date_time_field.py @@ -12,4 +12,4 @@ def db_value(self, value: datetime) -> str: def python_value(self, value: str) -> datetime: if value: - return datetime.fromisoformat(value) \ No newline at end of file + return datetime.fromisoformat(value) diff --git a/whois/database.py b/whois/database.py deleted file mode 100644 index 856976a..0000000 --- a/whois/database.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -from datetime import datetime, timedelta, timezone - -import peewee as pw -from werkzeug.security import check_password_hash, generate_password_hash -from playhouse.db_url import connect - - -db_url = os.environ.get("APP_DB_URL", "sqlite:///whohacks.sqlite") -db = connect(db_url) - - -class User(pw.Model): - id = pw.PrimaryKeyField() - username = pw.CharField(unique=True) - _password = pw.CharField(column_name="password", null=True) - display_name = pw.CharField() - flags = pw.BitField(null=True) - - is_hidden = flags.flag(1) - is_name_anonymous = flags.flag(2) - - class Meta: - database = db - - @classmethod - def register(cls, username, password, display_name=None): - """ - Creates user and hashes his password - :param username: used in login - :param password: plain text to be hashed - :param display_name: displayed username - :return: user instance - """ - user = cls.create( - username=username, display_name=display_name - ) - user.password = password - return user - - @classmethod - def register_from_sso(cls, username, display_name=None): - """ - Creates user without any password. Such users can only login via SSO. - :param username: used in login - :param display_name: displayed username - :return: user instance - """ - user = cls.create( - username=username, display_name=display_name - ) - user._password = None - return user - - def __str__(self): - if self.is_name_anonymous or self.is_hidden: - return "anonymous" - else: - return self.display_name - - @property - def is_active(self): - return self.username is not None - - @property - def is_authenticated(self): - return True - - @property - def is_anonymous(self): - """ - Needed by flask login - :return: - """ - return False - - @property - def is_sso(self) -> bool: - return not self._password - - @property - def password(self): - return self._password - - @password.setter - def password(self, new_password): - if len(new_password) < 3: - raise Exception("too_short") - else: - self._password = generate_password_hash(new_password) - - def auth(self, password): - return check_password_hash(self.password, password) - - -class IsoDateTimeField(pw.DateTimeField): - field_type = "DATETIME" - - def db_value(self, value: datetime) -> str: - if value: - return value.isoformat() - - def python_value(self, value: str) -> datetime: - if value: - return datetime.fromisoformat(value) - - -class Device(pw.Model): - mac_address = pw.FixedCharField(primary_key=True, unique=True, max_length=17) - hostname = pw.CharField(null=True) - last_seen = IsoDateTimeField() - owner = pw.ForeignKeyField( - User, backref="devices", column_name="user_id", null=True - ) - flags = pw.BitField(null=True) - - is_hidden = flags.flag(1) - is_new = flags.flag(2) - is_infrastructure = flags.flag(4) - is_esp = flags.flag(8) - is_laptop = flags.flag(16) - - class Meta: - database = db - - def __str__(self): - return self.mac_address - - @classmethod - def get_recent(cls, days=0, hours=0, minutes=30, seconds=0): - """ - Returns list of last connected devices - :param hours: - :param minutes: - :param seconds: - :return: list of devices - """ - recent_time = datetime.now(timezone.utc) - timedelta( - days=days, hours=hours, minutes=minutes, seconds=seconds - ) - devices = list( - cls.select().where(cls.last_seen > recent_time).order_by(cls.last_seen) - ) - return devices - - @classmethod - def update_or_create(cls, mac_address, last_seen, hostname=None): - try: - res = cls.create( - mac_address=mac_address, hostname=hostname, last_seen=last_seen - ) - - except pw.IntegrityError: - res = cls.get(cls.mac_address == mac_address) - res.last_seen = last_seen - res.hostname = hostname - - res.save() diff --git a/whois/types/bitfield.py b/whois/entity/bitfield.py similarity index 72% rename from whois/types/bitfield.py rename to whois/entity/bitfield.py index 7f0d9d6..dfb6db8 100644 --- a/whois/types/bitfield.py +++ b/whois/entity/bitfield.py @@ -1,14 +1,8 @@ from __future__ import annotations -import sqlalchemy.types as types -from sqlalchemy.ext.mutable import Mutable - -class BitField(Mutable, types.TypeDecorator): - impl = types.INTEGER - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class BitField: + def __init__(self): self._flags = 0 def has_flag(self, mask: int) -> bool: @@ -23,11 +17,21 @@ def unset_flag(self, mask: int) -> None: """Un-set the specifig flag in the bitfield. Can un-set many bits at once""" self._flags &= ~mask + def clear(self) -> None: + """Set all bits of the BitMap to 0's""" + self._flags = 0 + def __int__(self) -> int: return int(self._flags) def __repr__(self) -> str: return f"BitField({bin(self._flags)})" + def __format__(self, format_spec): + return self.__repr__() + def __eq__(self, other: BitField) -> bool: return int(self) == int(other) + + def __hash__(self) -> int: + return hash(self._flags) diff --git a/whois/entity/device.py b/whois/entity/device.py new file mode 100644 index 0000000..69b32d1 --- /dev/null +++ b/whois/entity/device.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from enum import Enum + +from whois.entity.bitfield import BitField +from whois.entity.iso_date_time import IsoDateTimeField + + +class DeviceFlags(Enum): + is_hidden = 1 + is_new = 2 + is_infrastructure = 4 + is_esp = 8 + is_laptop = 16 + + +@dataclass +class Device: + mac_address: str + hostname: str + last_seen: IsoDateTimeField + owner: int + flags: BitField diff --git a/whois/entity/iso_date_time.py b/whois/entity/iso_date_time.py new file mode 100644 index 0000000..0af84d7 --- /dev/null +++ b/whois/entity/iso_date_time.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class IsoDateTimeField: + value: datetime + + @property + def db_value(self) -> str: + return self.value.isoformat() + + # def python_value(self, value: str) -> datetime: + # if value: + # return datetime.fromisoformat(value) diff --git a/whois/entity/user.py b/whois/entity/user.py new file mode 100644 index 0000000..1b0d127 --- /dev/null +++ b/whois/entity/user.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from enum import Enum + +from werkzeug.security import check_password_hash, generate_password_hash + +from whois.entity.bitfield import BitField + + +class UserFlags(Enum): + is_hidden = 1 + is_name_anonymous = 2 + + +@dataclass +class User: + username: str + display_name: str + id: int = None + flags: BitField = BitField() + _password: str = None + + def __str__(self): + if self.is_name_anonymous or self.is_hidden: + return "anonymous" + else: + return self.display_name + + def get_id(self) -> int: + """Get the user id. Required by flask login.""" + return self.id + + @property + def is_hidden(self) -> bool: + return self.flags.has_flag(UserFlags.is_hidden.value) + + @property + def is_name_anonymous(self) -> bool: + return self.flags.has_flag(UserFlags.is_name_anonymous.value) + + @property + def is_active(self): + return self.username is not None + + @property + def is_authenticated(self): + return True + + @property + def is_anonymous(self): + """ + Needed by flask login + :return: + """ + return False + + @property + def is_sso(self) -> bool: + return not self._password + + @property + def password(self) -> str: + return self._password + + @password.setter + def password(self, new_password: str): + if len(new_password) < 3: + raise Exception( + "Password is too short. It should contain at least 3 characters." + ) + else: + self._password = generate_password_hash(new_password) + + def auth(self, password: str) -> bool: + return check_password_hash(self._password, password) diff --git a/whois/helpers.py b/whois/helpers.py index af16206..cda7546 100644 --- a/whois/helpers.py +++ b/whois/helpers.py @@ -1,8 +1,12 @@ import logging +from datetime import datetime, timedelta, timezone from functools import wraps -from urllib.parse import urlparse, urljoin +from typing import List +from urllib.parse import urljoin, urlparse -from flask import request, abort +from flask import abort, request + +from whois.entity.device import Device from whois.settings import ip_mask logging.basicConfig(level=logging.INFO) @@ -73,3 +77,15 @@ def func(*a, **kw): return func return decorator + + +def filter_recent(delta: timedelta, devices: List[Device]): + """ + Returns list of last connected devices + :param hours: + :param minutes: + :param seconds: + :return: list of devices + """ + recent_time = datetime.now(timezone.utc) - delta + return list(filter(lambda device: device.recent_time > recent_time, devices)) diff --git a/whois/settings.py b/whois/settings.py index 3c49619..b10c1aa 100644 --- a/whois/settings.py +++ b/whois/settings.py @@ -45,4 +45,4 @@ # ip_mask = "192.168.88.1-255" ip_mask = os.environ.get("APP_IP_MASK", None) if not ip_mask: - raise ValueError("ERROR: APP_IP_MASK environment variable was not set!") \ No newline at end of file + raise ValueError("ERROR: APP_IP_MASK environment variable was not set!") diff --git a/whois/settings_test.py b/whois/settings_test.py new file mode 100644 index 0000000..7652a50 --- /dev/null +++ b/whois/settings_test.py @@ -0,0 +1,29 @@ +import os + +from pytz import timezone + +APP_VERSION = "1.5.0" +APP_TITLE = "👀 kto hakuje" +APP_NAME = "Kto Hakuje" + +APP_BASE_URL = "whois.at.hsp.sh" + +APP_HOME_URL = "//hsp.sh" +APP_WIKI_URL = "//wiki.hsp.sh/whois" +APP_REPO_URL = "//github.com/hspsh/whohacks" + +APP_TIMEZONE = timezone(os.environ.get("APP_TIMEZONE", "Europe/Warsaw")) + +# mikrtotik ip, or other reporting devices +whitelist = ["192.168.88.1"] +host = "0.0.0.0" +user_flags = {1: "hidden", 2: "name_anonymous"} +device_flags = {1: "hidden", 2: "new", 4: "infrastructure", 8: "esp", 16: "laptop"} + +recent_time = {"minutes": 20} +worker_frequency_s = 60 + +oidc_enabled = True + +SECRET_KEY = "test_key" +ip_mask = "127.0.0.1:5000" diff --git a/whois/web.py b/whois/web.py index 86a6fb2..52b9811 100644 --- a/whois/web.py +++ b/whois/web.py @@ -1,37 +1,22 @@ import logging -import os -from datetime import datetime - -from flask import ( - Flask, - flash, - render_template, - redirect, - url_for, - request, - jsonify, - abort, -) -from flask_cors import CORS -from flask_login import ( - LoginManager, - login_required, - current_user, - login_user, - logout_user, -) +from datetime import datetime, timedelta, timezone + from authlib.integrations.flask_client import OAuth +from flask import (Flask, abort, flash, jsonify, redirect, render_template, + request, url_for) +from flask_cors import CORS +from flask_login import (LoginManager, current_user, login_required, + login_user, logout_user) +from sqlalchemy.orm.exc import NoResultFound from whois import settings -from whois.database import db, Device, User -from whois.helpers import ( - owners_from_devices, - filter_hidden, - unclaimed_devices, - filter_anon_names, - ip_range, - in_space_required, -) +from whois.data.db.database import Database +from whois.data.repository.device_repository import DeviceRepository +from whois.data.repository.user_repository import UserRepository +from whois.entity.user import User, UserFlags +from whois.helpers import (filter_anon_names, filter_hidden, filter_recent, + in_space_required, ip_range, owners_from_devices, + unclaimed_devices) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -40,6 +25,9 @@ app.config.from_object("whois.settings") login_manager = LoginManager() login_manager.init_app(app) +database = Database() +device_repository = DeviceRepository(database) +user_repository = UserRepository(database) if settings.oidc_enabled: oauth = OAuth(app) @@ -63,8 +51,8 @@ def local_time(dt: datetime): @login_manager.user_loader def load_user(user_id): try: - return User.get_by_id(user_id) - except User.DoesNotExist as exc: + return user_repository.get_by_id(user_id) + except NoResultFound as exc: app.logger.error("{}".format(exc)) return None @@ -72,7 +60,7 @@ def load_user(user_id): @app.before_request def before_request(): app.logger.info("connecting to db") - db.connect() + database.connect() if request.headers.getlist("X-Forwarded-For"): ip_addr = request.headers.getlist("X-Forwarded-For")[0] @@ -91,8 +79,12 @@ def before_request(): @app.teardown_appcontext def after_request(error): - app.logger.info("closing db") - db.close() + if database.is_connected: + app.logger.info("Closing the database connection") + database.disconnect() + else: + app.logger.info("Database connection was already closed") + if error: app.logger.error(error) @@ -100,7 +92,8 @@ def after_request(error): @app.route("/") def index(): """Serve list of people in hs, show panel for logged users""" - recent = Device.get_recent(**settings.recent_time) + devices = device_repository.get_all() + recent = filter_recent(timedelta(**settings.recent_time), devices) visible_devices = filter_hidden(recent) users = filter_hidden(owners_from_devices(visible_devices)) @@ -116,13 +109,14 @@ def index(): @login_required @app.route("/devices") def devices(): - recent = Device.get_recent(**settings.recent_time) + devices = device_repository.get_all() + recent = filter_recent(timedelta(**settings.recent_time), devices) visible_devices = filter_hidden(recent) users = filter_hidden(owners_from_devices(visible_devices)) if current_user.is_authenticated: unclaimed = unclaimed_devices(recent) - mine = current_user.devices + mine = filter(lambda device: device.owner == current_user.get_id(), devices) return render_template( "devices.html", unclaimed=unclaimed, @@ -147,13 +141,14 @@ def now_at_space(): if key in request.args: period[key] = request.args.get(key, default=0, type=int) - devices = filter_hidden(Device.get_recent(**period)) - users = filter_hidden(owners_from_devices(devices)) + devices = device_repository.get_all() + recent = filter_recent(timedelta(**settings.recent_time), devices) + users = filter_hidden(owners_from_devices(recent)) data = { "users": sorted(map(str, filter_anon_names(users))), "headcount": len(users), - "unknown_devices": len(unclaimed_devices(devices)), + "unknown_devices": len(unclaimed_devices(recent)), } app.logger.info("sending request for /api/now {}".format(data)) @@ -186,8 +181,8 @@ def device_view(mac_address): """Get info about device, claim device, release device""" try: - device = Device.get(Device.mac_address == mac_address) - except Device.DoesNotExist as exc: + device = device_repository.get_by_mac_address(mac_address) + except NoResultFound as exc: app.logger.error("{}".format(exc)) return abort(404) @@ -243,14 +238,13 @@ def register(): password = request.form["password"] try: - user = User.register(username, password, display_name) + user = register(username, password, display_name) except Exception as exc: if exc.args[0] == "too_short": flash("Password too short, minimum length is 3") else: print(exc) else: - user.save() app.logger.info("registered new user: {}".format(user.username)) flash("Registered.", "info") @@ -269,16 +263,17 @@ def login(): if request.method == "POST": try: - user = User.get(User.username == request.form["username"]) - except User.DoesNotExist: + username = request.form["username"] + user = user_repository.get_by_username(username) + except NoResultFound: user = None - if user is not None: + if user: if user.is_sso: # User created via sso -> redirect to sso login app.logger.info("Redirect to SSO user: {}".format(user.username)) - return redirect(url_for("login_oauth")) - elif user.auth(request.form["password"]) is True: + return redirect(url_for("login_oauth")) + elif user.auth(request.form["password"]): # User password hash match -> login user successfully login_user(user) app.logger.info("logged in: {}".format(user.username)) @@ -314,14 +309,13 @@ def callback(): user_info = oauth.sso.parse_id_token(token) if user_info: try: - user = User.get(User.username == user_info["preferred_username"]) - except User.DoesNotExist: + user = user_repository.get_by_username(user_info["preferred_username"]) + except NoResultFound: username = user_info["preferred_username"] app.logger.info( - f"No SSO-loggined user: {username}.\n" - f"Register user {username}", + f"No SSO-loggined user: {username}.\n" f"Register user {username}", ) - user = User.register_from_sso(username=username, display_name=username) + user = register_from_sso(username=username, display_name=username) if user is not None: login_user(user) @@ -369,14 +363,49 @@ def profile_edit(): else: current_user.display_name = request.form["display_name"] new_flags = request.form.getlist("flags") - current_user.is_hidden = "hidden" in new_flags - current_user.is_name_anonymous = "anonymous for public" in new_flags + if "hidden" in new_flags: + current_user.flags.set_flag(UserFlags.is_hidden.value) + else: + current_user.flags.unset_flag(UserFlags.is_hidden.value) + + if "anonymous for public" in new_flags: + current_user.flags.set_flag(UserFlags.is_name_anonymous.value) + else: + current_user.flags.unset_flag(UserFlags.is_name_anonymous.value) + app.logger.info( "flags: got {} set {:b}".format(new_flags, current_user.flags) ) - current_user.save() + user_repository.update(current_user) + flash("Saved", "success") else: flash("Invalid password", "error") return render_template("profile.html", user=current_user, **common_vars_tpl) + + +def register(username, password, display_name=None): + """ + Creates user and hashes his password + :param username: used in login + :param password: plain text to be hashed + :param display_name: displayed username + :return: user instance + """ + user = User(username=username, display_name=display_name) + user.password = password + user_repository.insert(user) + return user + + +def register_from_sso(username, display_name=None): + """ + Creates user without any password. Such users can only login via SSO. + :param username: used in login + :param display_name: displayed username + :return: user instance + """ + user = User(username=username, display_name=display_name) + user_repository.insert(user) + return user diff --git a/whois/worker.py b/whois/worker.py index 87bd290..e4b99e1 100644 --- a/whois/worker.py +++ b/whois/worker.py @@ -1,9 +1,9 @@ -from datetime import datetime, timezone import logging import time +from datetime import datetime, timezone from whois import settings -from whois.database import db, Device +from whois.data.db.database import Device, db from whois.mikrotik import fetch_leases logger = logging.getLogger("mikrotik-worker")