diff --git a/backend/authentication.py b/backend/authentication.py index 79f467b..80c3401 100644 --- a/backend/authentication.py +++ b/backend/authentication.py @@ -19,13 +19,13 @@ class User(BaseModel): def create_user(user: User) -> None: with Database() as connection: connection.execute( - "INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password) + "INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.password) ) def get_user(email: str) -> User: with Database() as connection: - user_row = connection.execute("SELECT * FROM user WHERE email = ?", (email,)) + user_row = connection.execute("SELECT * FROM users WHERE email = ?", (email,)) for row in user_row: return User(**row) raise Exception("User not found") diff --git a/backend/database.py b/backend/database.py index 6edd3a2..8d80fd4 100644 --- a/backend/database.py +++ b/backend/database.py @@ -7,16 +7,28 @@ import sqlglot from dbutils.pooled_db import PooledDB from logging import Logger +from sqlalchemy.engine.url import make_url from backend.logger import get_logger +POOL = None class Database: + DIALECT_PLACEHOLDERS = { + "sqlite": "?", + "postgresql": "%s", + "mysql": "%s", + } + def __init__(self, connection_string: str = None, logger: Logger = None): self.connection_string = connection_string or os.getenv("DATABASE_URL") self.logger = logger or get_logger() + self.url = make_url(self.connection_string) + self.logger.debug("Creating connection pool") - self.pool = self._create_pool() + global POOL + POOL = POOL or self._create_pool() # Makes the pool a singleton + self.pool = POOL self.conn = None def __enter__(self) -> "Database": @@ -38,6 +50,7 @@ def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], def execute(self, query: str, params: Optional[tuple] = None) -> Any: cursor = self.conn.cursor() try: + query = query.replace("?", self.DIALECT_PLACEHOLDERS[self.url.drivername]) self.logger.debug(f"Executing query: {query}") cursor.execute(query, params or ()) return cursor @@ -64,10 +77,10 @@ def initialize_schema(self): try: self.logger.debug("Initializing database schema") sql_script = Path(__file__).parent.joinpath('db_init.sql').read_text() - transpiled_sql = sqlglot.transpile(sql_script, read='sqlite', write=self.connection_string.split(":")[0]) + transpiled_sql = sqlglot.transpile(sql_script, read='sqlite', write=self.url.drivername.replace("postgresql", "postgres")) for statement in transpiled_sql: self.execute(statement) - self.logger.debug(f"Database schema initialized successfully for {self.connection_string.split(':')[0]}") + self.logger.info(f"Database schema initialized successfully for {self.url.drivername}") except Exception as e: self.logger.exception("Schema initialization failed", exc_info=e) raise @@ -77,22 +90,17 @@ def _create_pool(self) -> PooledDB: import sqlite3 Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True) return PooledDB(creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), maxconnections=5) - elif self.connection_string.startswith("postgres://"): + elif self.connection_string.startswith("postgresql://"): import psycopg2 - return PooledDB(creator=psycopg2, dsn=self.connection_string.replace("postgres://", ""), maxconnections=5) + return PooledDB(creator=psycopg2, dsn=self.connection_string, maxconnections=5) elif self.connection_string.startswith("mysql://"): import mysql.connector - return PooledDB(creator=mysql.connector, database=self.connection_string.replace("mysql://", ""), maxconnections=5) + return PooledDB(creator=mysql.connector, user=self.url.username, password=self.url.password, host=self.url.host, port=self.url.port, database=self.url.database, maxconnections=5) + elif self.connection_string.startswith("mysql+pymysql://"): + import mysql.connector + return PooledDB(creator=mysql.connector, user=self.url.username, password=self.url.password, host=self.url.host, port=self.url.port, database=self.url.database, maxconnections=5) elif self.connection_string.startswith("sqlserver://"): import pyodbc return PooledDB(creator=pyodbc, dsn=self.connection_string.replace("sqlserver://", ""), maxconnections=5) else: - raise ValueError("Unsupported database type") - - - -if __name__ == "__main__": - load_dotenv() - with Database(os.getenv("DATABASE_URL")) as db: - db.execute("DELETE FROM user WHERE email IN ('alexis')") - db.execute("DELETE FROM chat WHERE user_id IN ('alexis')") + raise ValueError(f"Unsupported database type: {self.url.drivername}") diff --git a/backend/db_init.sql b/backend/db_init.sql index 8c8e472..4258b3a 100644 --- a/backend/db_init.sql +++ b/backend/db_init.sql @@ -4,30 +4,30 @@ -- Paste here -- Replace "CREATE TABLE" with "CREATE TABLE IF NOT EXISTS" -CREATE TABLE IF NOT EXISTS "user" ( - "email" TEXT PRIMARY KEY, +CREATE TABLE IF NOT EXISTS "users" ( + "email" VARCHAR(255) PRIMARY KEY, "password" TEXT ); CREATE TABLE IF NOT EXISTS "chat" ( - "id" TEXT PRIMARY KEY, + "id" VARCHAR(255) PRIMARY KEY, "timestamp" DATETIME, - "user_id" TEXT, - FOREIGN KEY ("user_id") REFERENCES "user" ("email") + "user_id" VARCHAR(255), + FOREIGN KEY ("user_id") REFERENCES "users" ("email") ); CREATE TABLE IF NOT EXISTS "message" ( - "id" TEXT PRIMARY KEY, + "id" VARCHAR(255) PRIMARY KEY, "timestamp" DATETIME, - "chat_id" TEXT, + "chat_id" VARCHAR(255), "sender" TEXT, "content" TEXT, FOREIGN KEY ("chat_id") REFERENCES "chat" ("id") ); CREATE TABLE IF NOT EXISTS "feedback" ( - "id" TEXT PRIMARY KEY, - "message_id" TEXT, + "id" VARCHAR(255) PRIMARY KEY, + "message_id" VARCHAR(255), "feedback" TEXT, FOREIGN KEY ("message_id") REFERENCES "message" ("id") ); diff --git a/backend/user_management.py b/backend/user_management.py index 2cca3fd..f2a8f9d 100644 --- a/backend/user_management.py +++ b/backend/user_management.py @@ -19,19 +19,19 @@ class User(BaseModel): def create_user(user: User) -> None: with Database() as connection: connection.execute( - "INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password) + "INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.password) ) def user_exists(email: str) -> bool: with Database() as connection: - result = connection.fetchone("SELECT 1 FROM user WHERE email = ?", (email,)) + result = connection.fetchone("SELECT 1 FROM users WHERE email = ?", (email,)) return bool(result) def get_user(email: str) -> Optional[User]: with Database() as connection: - user_row = connection.fetchone("SELECT * FROM user WHERE email = ?", (email,)) + user_row = connection.fetchone("SELECT * FROM users WHERE email = ?", (email,)) if user_row: return User(email=user_row[0], password=user_row[1]) return None @@ -39,7 +39,7 @@ def get_user(email: str) -> Optional[User]: def delete_user(email: str) -> None: with Database() as connection: - connection.execute("DELETE FROM user WHERE email = ?", (email,)) + connection.execute("DELETE FROM users WHERE email = ?", (email,)) def authenticate_user(username: str, password: str) -> Optional[User]: