Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed litellm prints to tqdm and added execution metrics #45

Merged
merged 16 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added examples/op_examples/Skill:right_index/index
sidjha1 marked this conversation as resolved.
Show resolved Hide resolved
Binary file not shown.
Binary file added examples/op_examples/Skill:right_index/vecs
Binary file not shown.
121 changes: 102 additions & 19 deletions examples/op_examples/join_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from lotus.types import SemJoinCascadeArgs

lm = LM(model="gpt-4o-mini")
helper_lm = LM(model="gpt-3.5-turbo")
sidjha1 marked this conversation as resolved.
Show resolved Hide resolved
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")

lotus.settings.configure(lm=lm, rm=rm)
lotus.settings.configure(lm=lm, rm=rm, helper_lm=helper_lm)
data = {
"Course Name": [
"Digital Design and Integrated Circuits",
Expand All @@ -18,22 +19,104 @@
}

skills = [
"Math", "Computer Science", "Management", "Creative Writing", "Data Analysis", "Machine Learning",
"Project Management", "Problem Solving", "Singing", "Critical Thinking", "Public Speaking", "Teamwork",
"Adaptability", "Programming", "Leadership", "Time Management", "Negotiation", "Decision Making", "Networking",
"Painting", "Customer Service", "Marketing", "Graphic Design", "Nursery", "SEO", "Content Creation",
"Video Editing", "Sales", "Financial Analysis", "Accounting", "Event Planning", "Foreign Languages",
"Software Development", "Cybersecurity", "Social Media Management", "Photography", "Writing & Editing",
"Technical Support", "Database Management", "Web Development", "Business Strategy", "Operations Management",
"UI/UX Design", "Reinforcement Learning", "Data Visualization", "Product Management", "Cloud Computing",
"Agile Methodology", "Blockchain", "IT Support", "Legal Research", "Supply Chain Management", "Copywriting",
"Human Resources", "Quality Assurance", "Medical Research", "Healthcare Management", "Sports Coaching",
"Editing & Proofreading", "Legal Writing", "Human Anatomy", "Chemistry", "Physics", "Biology", "Psychology",
"Sociology", "Anthropology", "Political Science", "Public Relations", "Fashion Design", "Interior Design",
"Automotive Repair", "Plumbing", "Carpentry", "Electrical Work", "Welding", "Electronics", "Hardware Engineering",
"Circuit Design", "Robotics", "Environmental Science", "Marine Biology", "Urban Planning", "Geography",
"Agricultural Science", "Animal Care", "Veterinary Science", "Zoology", "Ecology", "Botany", "Landscape Design",
"Baking & Pastry", "Culinary Arts", "Bartending", "Nutrition", "Dietary Planning", "Physical Training", "Yoga",
"Math",
"Computer Science",
"Management",
"Creative Writing",
"Data Analysis",
"Machine Learning",
"Project Management",
"Problem Solving",
"Singing",
"Critical Thinking",
"Public Speaking",
"Teamwork",
"Adaptability",
"Programming",
"Leadership",
"Time Management",
"Negotiation",
"Decision Making",
"Networking",
"Painting",
"Customer Service",
"Marketing",
"Graphic Design",
"Nursery",
"SEO",
"Content Creation",
"Video Editing",
"Sales",
"Financial Analysis",
"Accounting",
"Event Planning",
"Foreign Languages",
"Software Development",
"Cybersecurity",
"Social Media Management",
"Photography",
"Writing & Editing",
"Technical Support",
"Database Management",
"Web Development",
"Business Strategy",
"Operations Management",
"UI/UX Design",
"Reinforcement Learning",
"Data Visualization",
"Product Management",
"Cloud Computing",
"Agile Methodology",
"Blockchain",
"IT Support",
"Legal Research",
"Supply Chain Management",
"Copywriting",
"Human Resources",
"Quality Assurance",
"Medical Research",
"Healthcare Management",
"Sports Coaching",
"Editing & Proofreading",
"Legal Writing",
"Human Anatomy",
"Chemistry",
"Physics",
"Biology",
"Psychology",
"Sociology",
"Anthropology",
"Political Science",
"Public Relations",
"Fashion Design",
"Interior Design",
"Automotive Repair",
"Plumbing",
"Carpentry",
"Electrical Work",
"Welding",
"Electronics",
"Hardware Engineering",
"Circuit Design",
"Robotics",
"Environmental Science",
"Marine Biology",
"Urban Planning",
"Geography",
"Agricultural Science",
"Animal Care",
"Veterinary Science",
"Zoology",
"Ecology",
"Botany",
"Landscape Design",
"Baking & Pastry",
"Culinary Arts",
"Bartending",
"Nutrition",
"Dietary Planning",
"Physical Training",
"Yoga",
]
data2 = pd.DataFrame({"Skill": skills})

Expand All @@ -42,7 +125,7 @@
df2 = pd.DataFrame(data2)
join_instruction = "By taking {Course Name:left} I will learn {Skill:right}"

cascade_args = SemJoinCascadeArgs(recall_target = 0.7, precision_target = 0.7)
cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7)
res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True)


Expand All @@ -51,4 +134,4 @@
print(f" Helper resolved {stats['join_resolved_by_helper_model']} LM calls")
print(f"Join cascade used {stats['total_LM_calls']} LM calls in total")
print(f"Naive join would require {df1.shape[0]*df2.shape[0]} LM calls")
print(res)
print(res)
13 changes: 11 additions & 2 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import hashlib
import logging
from typing import Any

import litellm
Expand All @@ -8,11 +9,15 @@
from litellm.utils import token_counter
from openai import OpenAIError
from tokenizers import Tokenizer
from tqdm import tqdm

import lotus
from lotus.cache import Cache
from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade

logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
logging.getLogger("httpx").setLevel(logging.CRITICAL)


class LM:
def __init__(
Expand All @@ -24,19 +29,23 @@ def __init__(
max_batch_size: int = 64,
tokenizer: Tokenizer | None = None,
max_cache_size: int = 1024,
safe_mode: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove safe_mode from LM because it's not used here anymore.

**kwargs: dict[str, Any],
):
self.model = model
self.max_ctx_len = max_ctx_len
self.max_tokens = max_tokens
self.max_batch_size = max_batch_size
self.tokenizer = tokenizer
self.safe_mode = safe_mode
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

self.stats: LMStats = LMStats()
self.cache = Cache(max_cache_size)

def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
def __call__(
self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

# Set top_logprobs if logprobs requested
Expand Down Expand Up @@ -70,7 +79,7 @@ def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any
def _process_uncached_messages(self, uncached_data, all_kwargs):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
for i in range(0, len(uncached_data), self.max_batch_size):
for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
return uncached_responses
Expand Down
22 changes: 22 additions & 0 deletions lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import lotus.models
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticAggOutput
from lotus.utils import show_safe_mode


def sem_agg(
docs: list[str],
model: lotus.models.LM,
user_instruction: str,
partition_ids: list[int],
safe_mode: bool = False,
) -> SemanticAggOutput:
"""
Aggregates multiple documents into a single answer using a model.
Expand Down Expand Up @@ -76,6 +78,12 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
template_tokens = model.count_tokens(template)
context_tokens = 0
doc_ctr = 1 # num docs in current prompt

if safe_mode:
print(f"Starting tree level {tree_level} aggregation with {len(docs)} docs")
estimated_LM_calls = 0
estimated_costs = 0

for idx in range(len(docs)):
partition_id = partition_ids[idx]
formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
Expand All @@ -98,6 +106,9 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
context_str = formatted_doc
context_tokens = new_tokens
doc_ctr += 1
if safe_mode:
estimated_LM_calls += 1
estimated_costs += model.count_tokens(prompt)
else:
context_str = context_str + formatted_doc
context_tokens += new_tokens
Expand All @@ -108,6 +119,13 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
lotus.logger.debug(f"Prompt added to batch: {prompt}")
batch.append([{"role": "user", "content": prompt}])
new_partition_ids.append(cur_partition_id)
if safe_mode:
estimated_LM_calls += 1
estimated_costs += model.count_tokens(prompt)

if safe_mode:
show_safe_mode(estimated_costs, estimated_LM_calls)

lm_output: LMOutput = model(batch)

summaries = lm_output.outputs
Expand All @@ -118,6 +136,8 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
lotus.logger.debug(f"Model outputs from tree level {tree_level}: {summaries}")
tree_level += 1

model.print_total_usage()

return SemanticAggOutput(outputs=summaries)


Expand All @@ -139,6 +159,7 @@ def __call__(
all_cols: bool = False,
suffix: str = "_output",
group_by: list[str] | None = None,
safe_mode: bool = False,
) -> pd.DataFrame:
"""
Applies semantic aggregation over a dataframe.
Expand Down Expand Up @@ -189,6 +210,7 @@ def __call__(
lotus.settings.lm,
formatted_usr_instr,
partition_ids,
safe_mode=safe_mode,
)

# package answer in a dataframe
Expand Down
12 changes: 12 additions & 0 deletions lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
from lotus.utils import show_safe_mode

from .postprocessors import extract_postprocess

Expand All @@ -16,6 +17,7 @@ def sem_extract(
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
safe_mode: bool = False,
) -> SemanticExtractOutput:
"""
Extracts attributes and values from a list of documents using a model.
Expand All @@ -39,6 +41,12 @@ def sem_extract(
lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
inputs.append(prompt)

# check if safe_mode is enabled
if safe_mode:
estimated_cost = sum(model.count_tokens(input) for input in inputs)
estimated_LM_calls = len(docs)
show_safe_mode(estimated_cost, estimated_LM_calls)

# call model
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})

Expand All @@ -47,6 +55,8 @@ def sem_extract(
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")

model.print_total_usage()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do model.print_total_usage everywhere only if we are in safe mode?


return SemanticExtractOutput(**postprocess_output.model_dump())


Expand All @@ -68,6 +78,7 @@ def __call__(
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
safe_mode: bool = False,
) -> pd.DataFrame:
"""
Extracts the attributes and values of a dataframe.
Expand Down Expand Up @@ -96,6 +107,7 @@ def __call__(
output_cols=output_cols,
extract_quotes=extract_quotes,
postprocessor=postprocessor,
safe_mode=safe_mode,
)

new_df = self._obj.copy()
Expand Down
Loading
Loading