Skip to content

Commit

Permalink
Merge pull request #4 from krflorian/feature/intent_classification
Browse files Browse the repository at this point in the history
intent classification
  • Loading branch information
krflorian authored Feb 4, 2024
2 parents cc1aeab + 0df8b9b commit 754f9ca
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ WORKDIR /app

#COPY pyproject.toml poetry.lock app.py ./
COPY app.py requirements.txt ./
RUN pip install -r requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY ./src /app/src
RUN touch README.md
Expand Down
34 changes: 19 additions & 15 deletions src/nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,38 @@ class Intent(Enum):
DECKBUILDING = "deckbuilding"
RULES = "rules"
CONVERSATION = "conversation"
# BAD_INTENTION = "bad_intention"
MALICIOUS = "malicious"


INTENT_MAPPER = {
"rules_question": Intent.RULES,
"rules": Intent.RULES,
"deck_building": Intent.DECKBUILDING,
"specific_cards": Intent.DECKBUILDING,
"card_info": Intent.DECKBUILDING,
"trading_cards": Intent.DECKBUILDING,
"greeting": Intent.CONVERSATION,
"illegal": Intent.MALICIOUS,
"cheating": Intent.MALICIOUS,
"code": Intent.MALICIOUS,
"program": Intent.MALICIOUS,
"script": Intent.MALICIOUS,
}


def classify_intent(text: str, classifier: Pipeline) -> tuple[str, float]:
"""Classify the user intent into one of the classes in Intent."""

hypothesis_template = "You are a chatbot that answers magic the Gathering Questions. The user wants to talk about {}"

intent_mapper = {
"rules_question": "rules",
"rules": "rules",
"deck_building": "deckbuilding",
"specific_cards": "deckbuilding",
"card_info": "deckbuilding",
"trading_cards": "deckbuilding",
"illegal": "conversation", # TODO should be its own class
"cheating": "conversation", # TODO should be its own class
"greeting": "conversation",
}

output = classifier(
text,
list(intent_mapper.keys()),
list(INTENT_MAPPER.keys()),
hypothesis_template=hypothesis_template,
multi_label=False,
)

intent = intent_mapper[output["labels"][0]]
intent = INTENT_MAPPER.get(output["labels"][0], "greeting")
score = output["scores"][0]

logger.info(f"classified intent: {intent} {score:.2f}")
Expand Down

0 comments on commit 754f9ca

Please sign in to comment.