Skip to content

Commit

Permalink
refactor storage, add embedding models
Browse files Browse the repository at this point in the history
  • Loading branch information
githubering182 committed Aug 14, 2024
1 parent 0564066 commit 1b1dc90
Show file tree
Hide file tree
Showing 23 changed files with 250 additions and 134 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ file_store
storage
worker_tmp
temp_zip
*.db

node_modules
.pnp
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.backend
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.11-slim-bookworm
FROM python:3.12-slim-bookworm

ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.storage
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.11-slim-bookworm
FROM python:3.12-slim-bookworm

ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.tests
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.11-slim-bookworm
FROM python:3.12-slim-bookworm

ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ services:
depends_on:
- iss-test-storage-db
- iss-test-back
command: python3 src/app.py
command: python3 src/main.py
networks:
- iss_test_network

Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ services:
depends_on:
- iss-storage-db
- iss-back
command: python3 src/app.py
command: python3 src/main.py
networks:
- iss_network

Expand Down
2 changes: 1 addition & 1 deletion scripts/file_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from math import ceil
from typing import Any
from threading import Thread, Lock
from shared.db_manager import DataBase, GridFSBucket
from shared.storage_db import DataBase, GridFSBucket
from shared.utils import (
APP_BACKEND_URL,
SECRET_KEY,
Expand Down
4 changes: 2 additions & 2 deletions storage-app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ click-plugins==1.1.1
click-repl==0.3.0
dnspython==2.6.1
exceptiongroup==1.1.3
fastapi==0.103.0
fastapi==0.111.0
flower==2.0.1
h11==0.14.0
humanize==4.8.0
Expand Down Expand Up @@ -45,7 +45,6 @@ python-jose==3.3.0
rsa==4.9
six==1.16.0
certifi==2023.7.22
charset-normalizer==3.2.0
requests==2.31.0
urllib3==2.0.5
httpcore==1.0.4
Expand All @@ -64,3 +63,4 @@ requests==2.32.3
urllib3==2.2.2
videohash==3.0.1
yt-dlp==2024.7.25
sqlite-vec==0.1.0
57 changes: 0 additions & 57 deletions storage-app/src/app.py

This file was deleted.

19 changes: 19 additions & 0 deletions storage-app/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from uvicorn import run
from router.storage import router as storage_router
from router.task import router as task_router
from router.token import router as token_router
from shared.settings import UVICORN_CONF, SELF_ORIGIN
from shared.setup import AuthMiddleware, lifespan

app: FastAPI = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)

app.add_middleware(CORSMiddleware, allow_origins=SELF_ORIGIN)
app.add_middleware(AuthMiddleware)

app.include_router(storage_router)
app.include_router(task_router)
app.include_router(token_router)

if __name__ == "__main__": run(**UVICORN_CONF)
2 changes: 1 addition & 1 deletion storage-app/src/router/router_tests/storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from asyncio import new_event_loop, set_event_loop, get_event_loop
from os.path import abspath
from sys import path
from shared.db_manager import (
from shared.storage_db import (
DataBase,
get_db_uri,
AsyncIOMotorClient,
Expand Down
8 changes: 4 additions & 4 deletions storage-app/src/router/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from shared.models import UploadFile, Form, Annotated
from shared.app_services import Bucket, ObjectStreaming

router = APIRouter()
router = APIRouter(prefix="/api/storage")


@router.get("/api/storage/{bucket_name}/{file_id}/")
@router.get("/{bucket_name}/{file_id}/")
async def get_file(
request: Request,
bucket_name: str,
Expand All @@ -21,7 +21,7 @@ async def get_file(
)


@router.post("/api/storage/{bucket_name}/")
@router.post("/{bucket_name}/")
async def upload_file(
bucket_name: str,
file: UploadFile,
Expand All @@ -33,7 +33,7 @@ async def upload_file(
return JSONResponse(status_code=status, content={"result": result})


@router.delete("/api/storage/{bucket_name}/{file_id}/")
@router.delete("/{bucket_name}/{file_id}/")
async def delete_file(bucket_name: str, file_id: str) -> JSONResponse:
project_bucket: Bucket = Bucket(bucket_name)
result, message = await project_bucket.delete_object(file_id)
Expand Down
6 changes: 3 additions & 3 deletions storage-app/src/router/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from worker import produce_download_task, WORKER
from celery.result import AsyncResult

router = APIRouter()
router = APIRouter(prefix="/api/task")


@router.post("/api/task/archive/")
@router.post("/archive/")
def archive_bucket(request_task: ArchiveTask) -> JSONResponse:
task: AsyncResult = produce_download_task.delay(
request_task.bucket_name,
Expand All @@ -19,7 +19,7 @@ def archive_bucket(request_task: ArchiveTask) -> JSONResponse:


# TODO: find the way to check if no such task
@router.get("/api/task/{task_id}/")
@router.get("/{task_id}/")
def check_task_status(task_id: str) -> JSONResponse:
task: AsyncResult = WORKER.AsyncResult(task_id)
response: dict[str, Any] = {"task_id": task_id, "status": task.status}
Expand Down
4 changes: 2 additions & 2 deletions storage-app/src/router/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from shared.settings import SECRET_KEY, SECRET_ALGO


router = APIRouter()
router = APIRouter(prefix="/api/temp_token")


@router.get("/api/temp_token/")
@router.get("/")
def get_temp_token() -> str:
return emit_token({"minutes": 5}, SECRET_KEY, SECRET_ALGO)
2 changes: 1 addition & 1 deletion storage-app/src/shared/app_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shared.settings import CHUNK_SIZE
from shared.utils import get_object_id
from shared.models import UploadFile
from shared.db_manager import DataBase, AsyncIOMotorGridFSBucket
from shared.storage_db import DataBase, AsyncIOMotorGridFSBucket
from hashlib import md5
from motor.motor_asyncio import (
AsyncIOMotorGridOutCursor,
Expand Down
77 changes: 77 additions & 0 deletions storage-app/src/shared/embedding_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from sqlite3 import Connection, Cursor
from sqlite_vec import load as load_vec_module
from typing_extensions import Any, Callable
from numpy import ndarray
from enum import Enum
from .settings import HASH_SIZE, EMBEDDING_STORAGE_PATH


class Query(Enum):
MIGRATE = "create table if not exists storage using vec0(embedding float[{}]);"
INSERT = "insert into storage (rowid, embedding) values (?, ?);"
SELECT = """
select rowid, distance
from storage
where embedding match ?
order by distance
limit ?;
"""


class EmdeddingStorage:
__slots__ = ("conn", "corrupted", "reason")
_k_nearest = 3

def __init__(self):
self.corrupted = False
self.conn = Connection(EMBEDDING_STORAGE_PATH)
self.load_module()

def load_module(self):
self.conn.enable_load_extension(True)
load_vec_module(self.conn)
self.conn.enable_load_extension(False)

def __enter__(self) -> "EmdeddingStorage": return self

def __exit__(self, *args, **kwargs):
try:
assert not self.corrupted, "Transaction corrupted: "
self.conn.commit()

except Exception as e:
self.conn.rollback()
raise ValueError(str(e) + self.reason)

finally: self.conn.close()

@staticmethod
def with_transaction(callback: Callable) -> Callable:
def inner(self, *args, **kwargs) -> Any:
if self.corrupted: return self.reason

cursor = self.connection.cursor()

try:
assert not self.corrupted, "Transaction corrupted"
return callback(self, cursor, *args, **kwargs)

except Exception as e:
self.reason = str(e)
self.corrupted = True

finally: cursor.close()

return inner

@with_transaction
def migrate(self, cur: Cursor): cur.execute(Query.MIGRATE.value.format(HASH_SIZE**2))

@with_transaction
def insert(self, cur: Cursor, file_id: str, embedding: ndarray):
cur.execute(Query.SELECT.value, [file_id, embedding])

@with_transaction
def select(self, cur: Cursor, embedding: ndarray):
result = cur.execute(Query.SELECT.value, [embedding, self._k_nearest]).fetchall()
return result
43 changes: 36 additions & 7 deletions storage-app/src/shared/hasher.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,57 @@
from imagehash import whash, ImageHash
from videohash import VideoHash
from videohash.utils import (
create_and_return_temporary_directory as mk_temp_dir,
does_path_exists
)
from PIL import Image
from PIL.ImageFile import ImageFile
from os.path import join, sep
from pathlib import Path
from asyncio import get_event_loop
from motor.motor_asyncio import AsyncIOMotorGridOut
from shared.settings import HASH_SIZE, TEMP_HASH_PATH
from numpy import asarray, float32, ndarray
from io import BytesIO

Image.ANTIALIAS = Image.Resampling.LANCZOS


class VHashPatch(VideoHash):
hash: ImageHash
def to_embedding(image: ImageFile) -> ndarray:
return asarray(
image
.convert("L")
.resize((HASH_SIZE, HASH_SIZE), Image.ANTIALIAS)
).flatten().astype(float32)


class IHash:
embedding: ndarray
_file: AsyncIOMotorGridOut

def __init__(self, file: AsyncIOMotorGridOut):
self._file = file
self.embedding = self._get_hash()

def _get_hash(self) -> ndarray:
get_event_loop().run_until_complete(self._get_buffer())
image = Image.open(self._buffer)
return to_embedding(image)

async def _get_buffer(self):
self._file.seek(0)
self._buffer = BytesIO(await self._file.read())


class VHash(VideoHash):
embedding: ndarray
_file: AsyncIOMotorGridOut

def __init__(self, *args, **kwargs):
file, *_ = args
def __init__(self, file: AsyncIOMotorGridOut):
self._file = file
super().__init__(*args, **kwargs)
super().__init__(storage_path=TEMP_HASH_PATH)
self.delete_storage_path()

def _calc_hash(self): self.hash = whash(self.image)
def _calc_hash(self): self.embedding = to_embedding(self.image)

def _create_required_dirs_and_check_for_errors(self):
if not self.storage_path: self.storage_path = mk_temp_dir()
Expand Down
Loading

0 comments on commit 1b1dc90

Please sign in to comment.