-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor storage, add embedding models
- Loading branch information
1 parent
0564066
commit 1b1dc90
Showing
23 changed files
with
250 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ file_store | |
storage | ||
worker_tmp | ||
temp_zip | ||
*.db | ||
|
||
node_modules | ||
.pnp | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from fastapi import FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from uvicorn import run | ||
from router.storage import router as storage_router | ||
from router.task import router as task_router | ||
from router.token import router as token_router | ||
from shared.settings import UVICORN_CONF, SELF_ORIGIN | ||
from shared.setup import AuthMiddleware, lifespan | ||
|
||
# Application object: interactive docs are exposed at /docs, ReDoc is turned
# off, and startup/shutdown handling is delegated to the shared lifespan.
app: FastAPI = FastAPI(
    docs_url="/docs",
    redoc_url=None,
    lifespan=lifespan,
)

# Middleware is registered in the same order as before: CORS first, then auth.
app.add_middleware(CORSMiddleware, allow_origins=SELF_ORIGIN)
app.add_middleware(AuthMiddleware)

# Routers are attached in a fixed order: storage, task, token.
app.include_router(storage_router)
app.include_router(task_router)
app.include_router(token_router)

# Start uvicorn only when this module is executed as a script.
if __name__ == "__main__":
    run(**UVICORN_CONF)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from sqlite3 import Connection, Cursor | ||
from sqlite_vec import load as load_vec_module | ||
from typing_extensions import Any, Callable | ||
from numpy import ndarray | ||
from enum import Enum | ||
from .settings import HASH_SIZE, EMBEDDING_STORAGE_PATH | ||
|
||
|
||
class Query(Enum):
    """SQL statements issued against the sqlite-vec backed ``storage`` table.

    Each member's ``value`` is the raw SQL text. ``MIGRATE`` contains a ``{}``
    placeholder that must be filled with the embedding dimension through
    ``str.format`` before execution; ``INSERT`` and ``SELECT`` use ``?``
    positional parameters.
    """

    # vec0 is a virtual-table module, so the table has to be created with
    # CREATE VIRTUAL TABLE; "create table ... using" is rejected by SQLite
    # as a syntax error.
    MIGRATE = "create virtual table if not exists storage using vec0(embedding float[{}]);"

    # Parameters: (rowid, embedding).
    INSERT = "insert into storage (rowid, embedding) values (?, ?);"

    # Nearest-neighbour lookup. Parameters: (query embedding, row limit).
    SELECT = """
        select rowid, distance
        from storage
        where embedding match ?
        order by distance
        limit ?;
    """
|
||
|
||
class EmdeddingStorage:
    """Context-managed access to the sqlite-vec ``storage`` table.

    Intended use is ``with EmdeddingStorage() as storage: ...``. The decorated
    operations (``migrate``, ``insert``, ``select``) never raise: the first
    failure is recorded in ``reason``, ``corrupted`` is set, and every later
    operation is skipped. On leaving the ``with`` block the pending work is
    committed, or rolled back with a ``ValueError`` if anything failed. The
    connection is always closed on exit.

    Attributes:
        conn: the open sqlite3 connection with the sqlite-vec extension loaded.
        corrupted: True once any decorated operation has raised.
        reason: text of the exception that corrupted the transaction
            (empty string while everything is fine).
    """

    __slots__ = ("conn", "corrupted", "reason")
    _k_nearest = 3  # row limit passed to the SELECT query by select()

    def __init__(self):
        self.corrupted = False
        # Initialised up front so __exit__ can always read it.
        self.reason = ""
        self.conn = Connection(EMBEDDING_STORAGE_PATH)
        try:
            self.load_module()
        except Exception:
            # Do not leak the connection when the extension cannot be loaded.
            self.conn.close()
            raise

    def load_module(self):
        """Load the sqlite-vec extension into the connection."""
        self.conn.enable_load_extension(True)
        try:
            load_vec_module(self.conn)
        finally:
            # Extension loading is re-disabled even if the load itself fails.
            self.conn.enable_load_extension(False)

    def __enter__(self) -> "EmdeddingStorage":
        return self

    def __exit__(self, *args, **kwargs):
        """Commit or roll back the pending work, then close the connection.

        Raises:
            ValueError: if an earlier operation corrupted the transaction
                (message ``"Transaction corrupted: <reason>"``) or if the
                commit itself fails (message is the commit error text).
        """
        try:
            # Explicit check instead of ``assert``: assertions are stripped
            # under ``python -O`` and a corrupted transaction would commit.
            if self.corrupted:
                self.conn.rollback()
                raise ValueError("Transaction corrupted: " + self.reason)

            try:
                self.conn.commit()
            except Exception as e:
                self.conn.rollback()
                raise ValueError(str(e)) from e
        finally:
            self.conn.close()

    # Defined as a plain function first so it can be applied as a decorator
    # inside the class body on every Python version (staticmethod objects are
    # only callable from 3.10 on); it is re-exposed as a staticmethod below.
    def with_transaction(callback: Callable) -> Callable:
        """Wrap ``callback(self, cursor, *args, **kwargs)`` with a managed cursor.

        The wrapper opens a cursor on ``self.conn``, passes it to ``callback``
        and closes it afterwards. If the storage is already corrupted the
        callback is skipped and ``self.reason`` is returned instead. If the
        callback raises, the error text is stored in ``self.reason``,
        ``self.corrupted`` is set and the wrapper returns ``None``.
        """
        from functools import wraps

        @wraps(callback)
        def inner(self, *args, **kwargs) -> Any:
            if self.corrupted:
                return self.reason

            cursor = self.conn.cursor()
            try:
                return callback(self, cursor, *args, **kwargs)
            except Exception as e:
                self.reason = str(e)
                self.corrupted = True
            finally:
                cursor.close()

        return inner

    @with_transaction
    def migrate(self, cur: Cursor):
        """Create the ``storage`` table if it does not exist yet."""
        # The embedding dimension is HASH_SIZE squared.
        cur.execute(Query.MIGRATE.value.format(HASH_SIZE**2))

    @with_transaction
    def insert(self, cur: Cursor, file_id: str, embedding: ndarray):
        """Store ``embedding`` under ``file_id`` (used as the rowid)."""
        # NOTE(review): presumably the embedding must be float32 with
        # HASH_SIZE**2 elements, and file_id must be usable as a rowid;
        # confirm against the callers.
        cur.execute(Query.INSERT.value, [file_id, embedding])

    @with_transaction
    def select(self, cur: Cursor, embedding: ndarray):
        """Return up to ``_k_nearest`` ``(rowid, distance)`` rows ordered by distance."""
        return cur.execute(Query.SELECT.value, [embedding, self._k_nearest]).fetchall()

    # Keep the public attribute a staticmethod, as in the original interface.
    with_transaction = staticmethod(with_transaction)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.