Skip to content

Commit

Permalink
rewrite how validation is done
Browse files Browse the repository at this point in the history
  • Loading branch information
shapiromatron committed Jul 10, 2024
1 parent 55476ce commit f7270c2
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 51 deletions.
119 changes: 74 additions & 45 deletions bmds_ui/desktop/components/database_form.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sqlite3
import tempfile
from pathlib import Path
from typing import Any

from pydantic import ValidationError
from textual import on, work
from textual.app import ComposeResult
from textual.containers import Grid, Horizontal
Expand All @@ -11,8 +14,9 @@
from textual.widgets import Button, Input, Label, Markdown

from ..actions import create_django_db
from ..config import Config, Database
from ..config import Config, Database, db_suffixes
from ..log import log
from .utils import get_error_string


def str_exists(value: str):
Expand All @@ -27,18 +31,40 @@ def path_exists(value: str):


def file_valid(value: str):
if not value.endswith(".sqlite3") and not value.endswith(".sqlite"):
return False
return True


def check_permission(path: str, db: str):
db_path = (Path(path).expanduser().resolve() / db).absolute()
try:
db_path.touch()
except PermissionError:
return False
return True
return any(value.endswith(suffix) for suffix in db_suffixes)


def additional_path_checks(db: Database):
# Additional path checks. We don't add to the pydantic model validation because we don't
# want to do this with every pydantic database model in config; but we do want these checks
# when we create or update our configuration file.

# create parent path if it doesn't already exist
if not db.path.parent.exists():
try:
db.path.parent.mkdir(parents=True)
except Exception:
raise ValueError(f"Cannot create path {db.path.parent}")

# check path is writable
if not db.path.exists():
try:
with tempfile.NamedTemporaryFile(dir=db.path.parent, delete=True, mode="w") as f:
f.write("test")
f.flush()
except Exception:
raise ValueError(f"Cannot write to {db.path.parent}")

# check existing database is loadable and writeable
if db.path.exists():
try:
conn = sqlite3.connect(db.path)
cursor = conn.cursor()
cursor.execute("CREATE TEMP TABLE test_writable (id INTEGER)")
conn.commit()
conn.close()
except (sqlite3.DatabaseError, sqlite3.OperationalError):
raise ValueError(f"Cannot edit database {db.path}. Is this a sqlite database?")


class NullWidget(Widget):
Expand All @@ -53,10 +79,19 @@ def compose(self) -> ComposeResult:


class FormError(Widget):
DEFAULT_CSS = """
.has-error {
background: $error;
color: white;
width: 100%;
padding: 0 3;
}
"""

message = reactive("", recompose=True)

def compose(self) -> ComposeResult:
yield Label(self.message)
yield Label(self.message, expand=True, classes="has-error" if len(self.message) > 0 else "")


class DatabaseFormModel(ModalScreen):
Expand Down Expand Up @@ -97,7 +132,7 @@ class DatabaseFormModel(ModalScreen):
"""

def __init__(self, *args, db: Database | None, **kw):
self.db = db
self.db: Database | None = db
super().__init__(*args, **kw)

def get_db_value(self, attr: str, default: Any):
Expand Down Expand Up @@ -128,16 +163,16 @@ def compose(self) -> ComposeResult:
id="path",
validators=[Function(path_exists)],
),
Label("Filename (*.sqlite)"),
Label("Filename (*.db)"),
Input(
value=path.name if path else "db.sqlite",
value=path.name if path else "bmds-database.db",
type="text",
id="filename",
validators=[Function(file_valid)],
),
Label("Description"),
Input(value=self.get_db_value("description", ""), type="text", id="description"),
FormError(classes="span4 error-text"),
FormError(classes="span4"),
Horizontal(
save_btn,
Button("Cancel", variant="default", id="db-edit-cancel"),
Expand All @@ -148,20 +183,24 @@ def compose(self) -> ComposeResult:
id="grid-db-form",
)

def db_valid(self) -> Database:
db = Database(
name=self.query_one("#name").value,
description=self.query_one("#description").value,
path=Path(self.query_one("#path").value) / self.query_one("#filename").value,
)
additional_path_checks(db)
return db

@on(Button.Pressed, "#db-create")
async def on_db_create(self) -> None:
name = self.query_one("#name").value
path = self.query_one("#path").value
db = self.query_one("#filename").value
description = self.query_one("#description").value
if not all(
(str_exists(name), path_exists(path), file_valid(db), check_permission(path, db))
):
self.query_one(FormError).message = "An error occurred."
try:
db = self.db_valid()
except (ValidationError, ValueError) as err:
self.query_one(FormError).message = get_error_string(err)
return
db_path = (Path(path).expanduser().resolve() / db).absolute()

config = Config.get()
db = Database(name=name, description=description, path=db_path)
self._create_django_db(config, db)

@work(exclusive=True, thread=True)
Expand All @@ -176,25 +215,15 @@ def _create_django_db(self, config, db):

@on(Button.Pressed, "#db-update")
async def on_db_update(self) -> None:
name = self.query_one("#name").value
path = self.query_one("#path").value
db = self.query_one("#filename").value
description = self.query_one("#description").value
if not all(
(str_exists(name), path_exists(path), file_valid(db), check_permission(path, db))
):
self.query_one(FormError).message = "An error occurred."
return

db_path = (Path(path).expanduser().resolve() / db).absolute()
if not db_path.exists():
message = f"Database does not exist: {db_path}"
self.query_one(FormError).message = message
try:
db = self.db_valid()
except (ValidationError, ValueError) as err:
self.query_one(FormError).message = get_error_string(err)
return

self.db.name = name
self.db.path = Path(path) / db
self.db.description = description
self.db.name = db.name
self.db.path = db.path
self.db.description = db.description
Config.sync()
log.info(f"Config updated for {self.db}")
self.dismiss(True)
Expand Down
4 changes: 0 additions & 4 deletions bmds_ui/desktop/components/style.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,3 @@ Footer {
.span4 {
column-span: 4
}

.error-text {
color: $error;
}
7 changes: 7 additions & 0 deletions bmds_ui/desktop/components/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from pydantic import ValidationError
from textual.app import App


def refresh(refresh: bool, app: App):
if refresh:
app.query_one("DatabaseList").refresh(layout=True, recompose=True)


def get_error_string(err: Exception) -> str:
if isinstance(err, ValidationError):
return "\n".join(f"{e['loc'][0]}: {e['msg']}" for e in err.errors())
return str(err)
21 changes: 19 additions & 2 deletions bmds_ui/desktop/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import ClassVar, Self
from uuid import UUID, uuid4

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator

from .. import __version__

Expand All @@ -15,20 +15,37 @@ def now() -> datetime:
return datetime.now(tz=UTC)


db_suffixes = (".db", ".sqlite", ".sqlite3")


class Database(BaseModel):
id: UUID = Field(default_factory=uuid4)
name: str = ""
name: str = Field(min_length=1)
description: str = ""
path: Path
created: datetime = Field(default_factory=now)
last_accessed: datetime = Field(default_factory=now)

model_config = ConfigDict(str_strip_whitespace=True)

def __str__(self) -> str:
return f"{self.name}: {self.path}"

def update_last_accessed(self):
self.last_accessed = datetime.now(tz=UTC)

@field_validator("path")
@classmethod
def path_check(cls, path: Path):
# resolve the fully normalized path
path = path.expanduser().resolve()

# check suffix
if path.suffix not in db_suffixes:
raise ValueError('Filename must end with the "sqlite" extension')

return path


class WebServer(BaseModel):
host: str = "127.0.0.1"
Expand Down

0 comments on commit f7270c2

Please sign in to comment.