Skip to content

Commit

Permalink
Merge pull request #5 from krflorian/feature/intent_classification
Browse files Browse the repository at this point in the history
Feature/intent classification
  • Loading branch information
krflorian authored Feb 11, 2024
2 parents 754f9ca + 2d6720e commit debe0b0
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 71 deletions.
8 changes: 1 addition & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#
FROM python:3.11
#RUN pip install poetry==1.4.2

#
WORKDIR /app

#COPY pyproject.toml poetry.lock app.py ./
# COPY pyproject.toml poetry.lock app.py ./
COPY app.py requirements.txt ./
RUN pip install --no-cache-dir --upgrade -r requirements.txt

Expand All @@ -14,10 +13,5 @@ RUN touch README.md
ARG HF_HOME="app/data/.cache"
ENV HF_HOME="app/data/.cache"

#
#RUN poetry config virtualenvs.create false
#RUN poetry install
#RUN pip install .

#
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
17 changes: 0 additions & 17 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import pipeline

from src.objects import Card, Document
from src.vector_db import VectorDB
from src.util import load_config
from src.hallucination import validate_answer
from src.nli import classify_intent

config: dict = load_config(Path("configs/config.yaml"))
vector_model: SentenceTransformer = SentenceTransformer(
Expand All @@ -19,13 +17,6 @@
hallucination_model: CrossEncoder = CrossEncoder(
config.get("halucination_model_name", "vectara/hallucination_evaluation_model")
)
nli_classifier_model = pipeline(
"zero-shot-classification",
model=config.get(
"nli_classifier", "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
),
)


db: dict[str, VectorDB] = {
"card": VectorDB.load(config.get("cards_db_file", None)),
Expand Down Expand Up @@ -150,13 +141,5 @@ async def validate_rag_chunks(
]


@app.post("/nli/")
async def classify_user_intent(
request: NLIClassificationRequest,
) -> NLIClassificationResponse:
intent, score = classify_intent(request.text, classifier=nli_classifier_model)
return NLIClassificationResponse(intent=intent, score=score)


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug", proxy_headers=True)
47 changes: 0 additions & 47 deletions src/nli.py

This file was deleted.

0 comments on commit debe0b0

Please sign in to comment.