Skip to content

Commit

Permalink
fix: mysql + postgres integrations
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Jan 3, 2024
1 parent a1d4fab commit 6fab2e6
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 30 deletions.
4 changes: 2 additions & 2 deletions backend/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ class User(BaseModel):
def create_user(user: User) -> None:
with Database() as connection:
connection.execute(
"INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)
"INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.password)
)


def get_user(email: str) -> User:
with Database() as connection:
user_row = connection.execute("SELECT * FROM user WHERE email = ?", (email,))
user_row = connection.execute("SELECT * FROM users WHERE email = ?", (email,))
for row in user_row:
return User(**row)
raise Exception("User not found")
Expand Down
38 changes: 23 additions & 15 deletions backend/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,28 @@
import sqlglot
from dbutils.pooled_db import PooledDB
from logging import Logger
from sqlalchemy.engine.url import make_url

from backend.logger import get_logger

POOL = None
class Database:
DIALECT_PLACEHOLDERS = {
"sqlite": "?",
"postgresql": "%s",
"mysql": "%s",
}

def __init__(self, connection_string: str = None, logger: Logger = None):
self.connection_string = connection_string or os.getenv("DATABASE_URL")
self.logger = logger or get_logger()

self.url = make_url(self.connection_string)

self.logger.debug("Creating connection pool")
self.pool = self._create_pool()
global POOL
POOL = POOL or self._create_pool() # Makes the pool a singleton
self.pool = POOL
self.conn = None

def __enter__(self) -> "Database":
Expand All @@ -38,6 +50,7 @@ def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException],
def execute(self, query: str, params: Optional[tuple] = None) -> Any:
cursor = self.conn.cursor()
try:
query = query.replace("?", self.DIALECT_PLACEHOLDERS[self.url.drivername])
self.logger.debug(f"Executing query: {query}")
cursor.execute(query, params or ())
return cursor
Expand All @@ -64,10 +77,10 @@ def initialize_schema(self):
try:
self.logger.debug("Initializing database schema")
sql_script = Path(__file__).parent.joinpath('db_init.sql').read_text()
transpiled_sql = sqlglot.transpile(sql_script, read='sqlite', write=self.connection_string.split(":")[0])
transpiled_sql = sqlglot.transpile(sql_script, read='sqlite', write=self.url.drivername.replace("postgresql", "postgres"))
for statement in transpiled_sql:
self.execute(statement)
self.logger.debug(f"Database schema initialized successfully for {self.connection_string.split(':')[0]}")
self.logger.info(f"Database schema initialized successfully for {self.url.drivername}")
except Exception as e:
self.logger.exception("Schema initialization failed", exc_info=e)
raise
Expand All @@ -77,22 +90,17 @@ def _create_pool(self) -> PooledDB:
import sqlite3
Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True)
return PooledDB(creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), maxconnections=5)
elif self.connection_string.startswith("postgres://"):
elif self.connection_string.startswith("postgresql://"):
import psycopg2
return PooledDB(creator=psycopg2, dsn=self.connection_string.replace("postgres://", ""), maxconnections=5)
return PooledDB(creator=psycopg2, dsn=self.connection_string, maxconnections=5)
elif self.connection_string.startswith("mysql://"):
import mysql.connector
return PooledDB(creator=mysql.connector, database=self.connection_string.replace("mysql://", ""), maxconnections=5)
return PooledDB(creator=mysql.connector, user=self.url.username, password=self.url.password, host=self.url.host, port=self.url.port, database=self.url.database, maxconnections=5)
elif self.connection_string.startswith("mysql+pymysql://"):
import mysql.connector
return PooledDB(creator=mysql.connector, user=self.url.username, password=self.url.password, host=self.url.host, port=self.url.port, database=self.url.database, maxconnections=5)
elif self.connection_string.startswith("sqlserver://"):
import pyodbc
return PooledDB(creator=pyodbc, dsn=self.connection_string.replace("sqlserver://", ""), maxconnections=5)
else:
raise ValueError("Unsupported database type")



if __name__ == "__main__":
load_dotenv()
with Database(os.getenv("DATABASE_URL")) as db:
db.execute("DELETE FROM user WHERE email IN ('alexis')")
db.execute("DELETE FROM chat WHERE user_id IN ('alexis')")
raise ValueError(f"Unsupported database type: {self.url.drivername}")
18 changes: 9 additions & 9 deletions backend/db_init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@
-- Paste here
-- Replace "CREATE TABLE" with "CREATE TABLE IF NOT EXISTS"

CREATE TABLE IF NOT EXISTS "user" (
"email" TEXT PRIMARY KEY,
CREATE TABLE IF NOT EXISTS "users" (
"email" VARCHAR(255) PRIMARY KEY,
"password" TEXT
);

CREATE TABLE IF NOT EXISTS "chat" (
"id" TEXT PRIMARY KEY,
"id" VARCHAR(255) PRIMARY KEY,
"timestamp" DATETIME,
"user_id" TEXT,
FOREIGN KEY ("user_id") REFERENCES "user" ("email")
"user_id" VARCHAR(255),
FOREIGN KEY ("user_id") REFERENCES "users" ("email")
);

CREATE TABLE IF NOT EXISTS "message" (
"id" TEXT PRIMARY KEY,
"id" VARCHAR(255) PRIMARY KEY,
"timestamp" DATETIME,
"chat_id" TEXT,
"chat_id" VARCHAR(255),
"sender" TEXT,
"content" TEXT,
FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
);

CREATE TABLE IF NOT EXISTS "feedback" (
"id" TEXT PRIMARY KEY,
"message_id" TEXT,
"id" VARCHAR(255) PRIMARY KEY,
"message_id" VARCHAR(255),
"feedback" TEXT,
FOREIGN KEY ("message_id") REFERENCES "message" ("id")
);
8 changes: 4 additions & 4 deletions backend/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@ class User(BaseModel):
def create_user(user: User) -> None:
with Database() as connection:
connection.execute(
"INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)
"INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.password)
)


def user_exists(email: str) -> bool:
with Database() as connection:
result = connection.fetchone("SELECT 1 FROM user WHERE email = ?", (email,))
result = connection.fetchone("SELECT 1 FROM users WHERE email = ?", (email,))
return bool(result)


def get_user(email: str) -> Optional[User]:
with Database() as connection:
user_row = connection.fetchone("SELECT * FROM user WHERE email = ?", (email,))
user_row = connection.fetchone("SELECT * FROM users WHERE email = ?", (email,))
if user_row:
return User(email=user_row[0], password=user_row[1])
return None


def delete_user(email: str) -> None:
with Database() as connection:
connection.execute("DELETE FROM user WHERE email = ?", (email,))
connection.execute("DELETE FROM users WHERE email = ?", (email,))


def authenticate_user(username: str, password: str) -> Optional[User]:
Expand Down

0 comments on commit 6fab2e6

Please sign in to comment.