Skip to content

Commit

Permalink
fixed the issue with model scan and storage;
Browse files Browse the repository at this point in the history
  • Loading branch information
ranjan-stha committed Oct 15, 2024
1 parent 910a088 commit 491843e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
1 change: 1 addition & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class RequestSchemaForEmbeddings(BaseModel):

@app.get("/")
async def home():
"""Returns a message"""
return Response(content="Embedding handler using models for texts", status_code=status.HTTP_200_OK)


Expand Down
6 changes: 3 additions & 3 deletions embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sentence_transformers import SentenceTransformer
from torch import Tensor

from utils import download_models
from utils import check_models


@dataclass
Expand All @@ -24,8 +24,8 @@ def __post_init__(self):
"""
Post initialization
"""
models_info = download_models(sent_embedding_model=self.model)
self.st_embedding_model = SentenceTransformer(model_name_or_path=models_info["model_path"])
model_path = check_models(sent_embedding_model=self.model)
self.st_embedding_model = SentenceTransformer(model_name_or_path=model_path)

def embed_documents(self, texts: list) -> np.ndarray:
"""
Expand Down
38 changes: 24 additions & 14 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,39 @@
logger.setLevel(logging.INFO)


def download_models(sent_embedding_model: str):
def download_model(embedding_model: str, models_path: str):
"""Downloads the model"""
logger.info("Downloading the model")
embedding_model_local_path = snapshot_download(repo_id=embedding_model, cache_dir=models_path)
return embedding_model_local_path


def check_models(sent_embedding_model: str):
"""Check if the model already exists"""
models_path = Path("/opt/models")
models_info_path = models_path / "model_info.json"

if not os.path.exists(models_path):
os.makedirs(models_path)

if not any(os.listdir(models_path)):
logger.info("Downloading the model")
embedding_model_local_path = snapshot_download(repo_id=sent_embedding_model, cache_dir=models_path)
embedding_model_local_path = download_model(embedding_model=sent_embedding_model, models_path=models_path)
models_info = {
"model": sent_embedding_model,
"model_path": embedding_model_local_path,
sent_embedding_model: embedding_model_local_path,
}

with open(models_info_path, "w", encoding="utf-8") as m_info_f:
json.dump(models_info, m_info_f)

else:
if os.path.exists(models_info_path):
logger.info("Models already exists.")
logger.info(models_info_path)
with open(models_info_path, "r", encoding="utf-8") as m_info_f:
models_info = json.load(m_info_f)

return models_info
return embedding_model_local_path
if os.path.exists(models_info_path):
with open(models_info_path, "r", encoding="utf-8") as m_info_f:
models_info_dict = json.load(m_info_f)
if sent_embedding_model not in models_info_dict.keys():
embedding_model_local_path = download_model(embedding_model=sent_embedding_model, models_path=models_path)
models_info_dict[sent_embedding_model] = embedding_model_local_path
with open(models_info_path, "w", encoding="utf-8") as m_info_f:
json.dump(models_info_dict, m_info_f)
return embedding_model_local_path

logger.info("Model is already available.")
return models_info_dict[sent_embedding_model]

0 comments on commit 491843e

Please sign in to comment.