Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmilkov committed Feb 5, 2024
1 parent 0c331dc commit 0a186ea
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions lilac/data/cluster_titling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import TypedDict
"""Functions for generating titles and categories for clusters of documents."""

from typing import Optional

import modal
from pydantic import BaseModel

_TOP_K_CENTRAL_DOCS = 7

Expand All @@ -11,7 +14,10 @@
"snippet's language is different than English, mention it in the title. "
)
EXAMPLE_MATH_SNIPPETS = [
'Explain each computation step in the evaluation of 90504690 / 37364. Exclude words; show only the math.',
(
'Explain each computation step in the evaluation of 90504690 / 37364. Exclude words; show only '
'the math.'
),
'What does 3-9030914617332 yield? Only respond with math and no words.',
'Provide a step-by-step calculation for 224 * 276429. Exclude words; show only the math.',
]
Expand All @@ -34,9 +40,8 @@ def get_titling_snippet(text: str) -> str:
EXAMPLE_TITLE = 'Title: Mathematical Calculations'

CATEGORY_SYSTEM_PROMPT = (
'You are a world-class category generator. Generate a short category (two words maximum) for the '
'provided titles. For example, given the titles "Translating English to Polish" '
'and "Translating Korean to English", generate "Category: Translation"'
'You are a world-class category generator. Generate a short category name (one or two words '
'long) for the provided titles. Do not use parentheses and do not generate alternative names.'
)
CATEGORY_EXAMPLE_TITLES = '\n'.join(
[
Expand All @@ -49,17 +54,40 @@ def get_titling_snippet(text: str) -> str:
EXAMPLE_CATEGORY = 'Category: Mathematics'


class Message(TypedDict):
class Message(BaseModel):
  """A single message in a chat conversation sent to the model."""

  # Speaker of the message; this file uses 'assistant' and 'user'.
  role: str
  # Text of the message.
  content: str


class SamplingParams(BaseModel):
  """Sampling parameters for the mistral model.

  The defaults below apply unless a caller overrides them; the generation helpers in this file
  override only `stop` (set to a newline).
  """

  # NOTE(review): the field names look like they mirror the remote model's sampler options;
  # their exact semantics are defined by the remote `Model.generate` endpoint -- confirm there.
  temperature: float = 0.0
  top_p: float = 1.0
  # Upper bound on generated tokens per output (presumably; enforced remotely).
  max_tokens: int = 50
  # Optional stop string; `None` means no stop string is sent.
  stop: Optional[str] = None
  spaces_between_special_tokens: bool = False


class MistralRequest(BaseModel):
  """Request to generate completions for a batch of chats with the Mistral model.

  The request is serialized with `model_dump()` before being passed to the remote function.
  """

  # One entry per conversation; each conversation is an ordered list of messages.
  chats: list[list[Message]]
  # Sampling configuration shared by every chat in the batch.
  # NOTE(review): a model instance is used as the default value; pydantic is documented to copy
  # non-immutable defaults per instance, so instances should not share state -- confirm for the
  # pinned pydantic version.
  sampling_params: SamplingParams = SamplingParams()


class MistralResponse(BaseModel):
  """Response from the Mistral model.

  Built by validating the dict returned from the remote function with `model_validate()`.
  """

  # Generated texts; presumably one per chat in the request, in the same order -- confirm
  # against the remote `Model.generate` implementation.
  outputs: list[str]


def generate_category_mistral(batch_titles: list[list[tuple[str, float]]]) -> list[str]:
"""Summarize a group of titles into a category."""
remote_fn = modal.Function.lookup('mistral-7b', 'Model.generate').remote
request: list[list[Message]] = []
request = MistralRequest(chats=[], sampling_params=SamplingParams(stop='\n'))
for ranked_titles in batch_titles:
# Keep only the first `_TOP_K_CENTRAL_DOCS` (currently 7) ranked titles.
titles = [title for title, _ in ranked_titles[:_TOP_K_CENTRAL_DOCS]]
Expand All @@ -69,13 +97,14 @@ def generate_category_mistral(batch_titles: list[list[tuple[str, float]]]) -> li
Message(role='assistant', content=EXAMPLE_CATEGORY),
Message(role='user', content=snippets),
]
request.append(messages)
request.chats.append(messages)

# TODO(smilkov): Add retry logic.
def request_with_retries() -> list[str]:
titles = remote_fn(request)
response_dict = remote_fn(request.model_dump())
response = MistralResponse.model_validate(response_dict)
result: list[str] = []
for title in titles:
for title in response.outputs:
title = title.strip()
if title.lower().startswith('category: '):
title = title[10:]
Expand All @@ -88,7 +117,7 @@ def request_with_retries() -> list[str]:
def generate_title_mistral(batch_docs: list[list[tuple[str, float]]]) -> list[str]:
"""Summarize a group of requests in a title of at most 5 words."""
remote_fn = modal.Function.lookup('mistral-7b', 'Model.generate').remote
request: list[list[Message]] = []
request = MistralRequest(chats=[], sampling_params=SamplingParams(stop='\n'))
for ranked_docs in batch_docs:
# Keep only the first `_TOP_K_CENTRAL_DOCS` (currently 7) ranked documents.
docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_DOCS]]
Expand All @@ -100,13 +129,14 @@ def generate_title_mistral(batch_docs: list[list[tuple[str, float]]]) -> list[st
Message(role='assistant', content=EXAMPLE_TITLE),
Message(role='user', content=snippets),
]
request.append(messages)
request.chats.append(messages)

# TODO(smilkov): Add retry logic.
def request_with_retries() -> list[str]:
titles = remote_fn(request)
response_dict = remote_fn(request.model_dump())
response = MistralResponse.model_validate(response_dict)
result: list[str] = []
for title in titles:
for title in response.outputs:
title = title.strip()
if title.lower().startswith('title: '):
title = title[7:]
Expand Down

0 comments on commit 0a186ea

Please sign in to comment.