Skip to content

Commit

Permalink
Changed litellm prints to tqdm and added execution metrics (#45)
Browse files Browse the repository at this point in the history
Replaced litellm prints with tqdm progress bar and added the execution
metrics (cost and tokens)

Safe_mode implementation:
sem_extract - completed
sem_map - completed
sem_agg - completed
sem_topk - completed
sem_join - completed
sem_filter - completed
  • Loading branch information
StanChan03 authored Dec 7, 2024
1 parent c7ed69a commit 99f96b2
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 21 deletions.
118 changes: 100 additions & 18 deletions examples/op_examples/join_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,104 @@
}

skills = [
"Math", "Computer Science", "Management", "Creative Writing", "Data Analysis", "Machine Learning",
"Project Management", "Problem Solving", "Singing", "Critical Thinking", "Public Speaking", "Teamwork",
"Adaptability", "Programming", "Leadership", "Time Management", "Negotiation", "Decision Making", "Networking",
"Painting", "Customer Service", "Marketing", "Graphic Design", "Nursery", "SEO", "Content Creation",
"Video Editing", "Sales", "Financial Analysis", "Accounting", "Event Planning", "Foreign Languages",
"Software Development", "Cybersecurity", "Social Media Management", "Photography", "Writing & Editing",
"Technical Support", "Database Management", "Web Development", "Business Strategy", "Operations Management",
"UI/UX Design", "Reinforcement Learning", "Data Visualization", "Product Management", "Cloud Computing",
"Agile Methodology", "Blockchain", "IT Support", "Legal Research", "Supply Chain Management", "Copywriting",
"Human Resources", "Quality Assurance", "Medical Research", "Healthcare Management", "Sports Coaching",
"Editing & Proofreading", "Legal Writing", "Human Anatomy", "Chemistry", "Physics", "Biology", "Psychology",
"Sociology", "Anthropology", "Political Science", "Public Relations", "Fashion Design", "Interior Design",
"Automotive Repair", "Plumbing", "Carpentry", "Electrical Work", "Welding", "Electronics", "Hardware Engineering",
"Circuit Design", "Robotics", "Environmental Science", "Marine Biology", "Urban Planning", "Geography",
"Agricultural Science", "Animal Care", "Veterinary Science", "Zoology", "Ecology", "Botany", "Landscape Design",
"Baking & Pastry", "Culinary Arts", "Bartending", "Nutrition", "Dietary Planning", "Physical Training", "Yoga",
"Math",
"Computer Science",
"Management",
"Creative Writing",
"Data Analysis",
"Machine Learning",
"Project Management",
"Problem Solving",
"Singing",
"Critical Thinking",
"Public Speaking",
"Teamwork",
"Adaptability",
"Programming",
"Leadership",
"Time Management",
"Negotiation",
"Decision Making",
"Networking",
"Painting",
"Customer Service",
"Marketing",
"Graphic Design",
"Nursery",
"SEO",
"Content Creation",
"Video Editing",
"Sales",
"Financial Analysis",
"Accounting",
"Event Planning",
"Foreign Languages",
"Software Development",
"Cybersecurity",
"Social Media Management",
"Photography",
"Writing & Editing",
"Technical Support",
"Database Management",
"Web Development",
"Business Strategy",
"Operations Management",
"UI/UX Design",
"Reinforcement Learning",
"Data Visualization",
"Product Management",
"Cloud Computing",
"Agile Methodology",
"Blockchain",
"IT Support",
"Legal Research",
"Supply Chain Management",
"Copywriting",
"Human Resources",
"Quality Assurance",
"Medical Research",
"Healthcare Management",
"Sports Coaching",
"Editing & Proofreading",
"Legal Writing",
"Human Anatomy",
"Chemistry",
"Physics",
"Biology",
"Psychology",
"Sociology",
"Anthropology",
"Political Science",
"Public Relations",
"Fashion Design",
"Interior Design",
"Automotive Repair",
"Plumbing",
"Carpentry",
"Electrical Work",
"Welding",
"Electronics",
"Hardware Engineering",
"Circuit Design",
"Robotics",
"Environmental Science",
"Marine Biology",
"Urban Planning",
"Geography",
"Agricultural Science",
"Animal Care",
"Veterinary Science",
"Zoology",
"Ecology",
"Botany",
"Landscape Design",
"Baking & Pastry",
"Culinary Arts",
"Bartending",
"Nutrition",
"Dietary Planning",
"Physical Training",
"Yoga",
]
data2 = pd.DataFrame({"Skill": skills})

Expand All @@ -42,7 +124,7 @@
df2 = pd.DataFrame(data2)
join_instruction = "By taking {Course Name:left} I will learn {Skill:right}"

cascade_args = SemJoinCascadeArgs(recall_target = 0.7, precision_target = 0.7)
cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7)
res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True)


Expand All @@ -51,4 +133,4 @@
print(f" Helper resolved {stats['join_resolved_by_helper_model']} LM calls")
print(f"Join cascade used {stats['total_LM_calls']} LM calls in total")
print(f"Naive join would require {df1.shape[0]*df2.shape[0]} LM calls")
print(res)
print(res)
11 changes: 9 additions & 2 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import hashlib
import logging
from typing import Any

import litellm
Expand All @@ -8,11 +9,15 @@
from litellm.utils import token_counter
from openai import OpenAIError
from tokenizers import Tokenizer
from tqdm import tqdm

import lotus
from lotus.cache import Cache
from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade

logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
logging.getLogger("httpx").setLevel(logging.CRITICAL)


class LM:
def __init__(
Expand All @@ -36,7 +41,9 @@ def __init__(
self.stats: LMStats = LMStats()
self.cache = Cache(max_cache_size)

def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
def __call__(
self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

# Set top_logprobs if logprobs requested
Expand Down Expand Up @@ -70,7 +77,7 @@ def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any
def _process_uncached_messages(self, uncached_data, all_kwargs):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
for i in range(0, len(uncached_data), self.max_batch_size):
for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
return uncached_responses
Expand Down
11 changes: 11 additions & 0 deletions lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def sem_agg(
model: lotus.models.LM,
user_instruction: str,
partition_ids: list[int],
safe_mode: bool = False,
) -> SemanticAggOutput:
"""
Aggregates multiple documents into a single answer using a model.
Expand Down Expand Up @@ -60,6 +61,10 @@ def node_doc_formatter(doc: str, ctr: int) -> str:
def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
return leaf_doc_formatter(doc, ctr) if tree_level == 0 else node_doc_formatter(doc, ctr)

if safe_mode:
# TODO: implement safe mode
lotus.logger.warning("Safe mode is not implemented yet")

tree_level = 0
summaries: list[str] = []
new_partition_ids: list[int] = []
Expand All @@ -76,6 +81,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
template_tokens = model.count_tokens(template)
context_tokens = 0
doc_ctr = 1 # num docs in current prompt

for idx in range(len(docs)):
partition_id = partition_ids[idx]
formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
Expand Down Expand Up @@ -108,6 +114,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
lotus.logger.debug(f"Prompt added to batch: {prompt}")
batch.append([{"role": "user", "content": prompt}])
new_partition_ids.append(cur_partition_id)

lm_output: LMOutput = model(batch)

summaries = lm_output.outputs
Expand All @@ -117,6 +124,8 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
docs = summaries
lotus.logger.debug(f"Model outputs from tree level {tree_level}: {summaries}")
tree_level += 1
if safe_mode:
model.print_total_usage()

return SemanticAggOutput(outputs=summaries)

Expand All @@ -139,6 +148,7 @@ def __call__(
all_cols: bool = False,
suffix: str = "_output",
group_by: list[str] | None = None,
safe_mode: bool = False,
) -> pd.DataFrame:
"""
Applies semantic aggregation over a dataframe.
Expand Down Expand Up @@ -189,6 +199,7 @@ def __call__(
lotus.settings.lm,
formatted_usr_instr,
partition_ids,
safe_mode=safe_mode,
)

# package answer in a dataframe
Expand Down
12 changes: 12 additions & 0 deletions lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
from lotus.utils import show_safe_mode

from .postprocessors import extract_postprocess

Expand All @@ -16,6 +17,7 @@ def sem_extract(
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
safe_mode: bool = False,
) -> SemanticExtractOutput:
"""
Extracts attributes and values from a list of documents using a model.
Expand All @@ -39,13 +41,21 @@ def sem_extract(
lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
inputs.append(prompt)

# check if safe_mode is enabled
if safe_mode:
estimated_cost = sum(model.count_tokens(input) for input in inputs)
estimated_LM_calls = len(docs)
show_safe_mode(estimated_cost, estimated_LM_calls)

# call model
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})

# post process results
postprocess_output = postprocessor(lm_output.outputs)
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
if safe_mode:
model.print_total_usage()

return SemanticExtractOutput(**postprocess_output.model_dump())

Expand All @@ -68,6 +78,7 @@ def __call__(
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
safe_mode: bool = False,
) -> pd.DataFrame:
"""
Extracts the attributes and values of a dataframe.
Expand Down Expand Up @@ -96,6 +107,7 @@ def __call__(
output_cols=output_cols,
extract_quotes=extract_quotes,
postprocessor=postprocessor,
safe_mode=safe_mode,
)

new_df = self._obj.copy()
Expand Down
16 changes: 16 additions & 0 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import lotus
from lotus.templates import task_instructions
from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
from lotus.utils import show_safe_mode

from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
from .postprocessors import filter_postprocess
Expand All @@ -22,6 +23,7 @@ def sem_filter(
cot_reasoning: list[str] | None = None,
strategy: str | None = None,
logprobs: bool = False,
safe_mode: bool = False,
) -> SemanticFilterOutput:
"""
Filters a list of documents based on a given user instruction using a language model.
Expand All @@ -47,6 +49,12 @@ def sem_filter(
lotus.logger.debug(f"input to model: {prompt}")
inputs.append(prompt)
kwargs: dict[str, Any] = {"logprobs": logprobs}

if safe_mode:
estimated_total_calls = len(docs)
estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
show_safe_mode(estimated_total_cost, estimated_total_calls)

lm_output: LMOutput = model(inputs, **kwargs)

postprocess_output = filter_postprocess(
Expand All @@ -56,6 +64,9 @@ def sem_filter(
lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")

if safe_mode:
model.print_total_usage()

return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None)


Expand Down Expand Up @@ -88,6 +99,7 @@ def learn_filter_cascade_thresholds(
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=False,
).outputs

best_combination, _ = learn_cascade_thresholds(
Expand Down Expand Up @@ -137,6 +149,7 @@ def __call__(
precision_target: float | None = None,
failure_probability: float | None = None,
return_stats: bool = False,
safe_mode: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
"""
Applies semantic filter over a dataframe.
Expand Down Expand Up @@ -221,6 +234,7 @@ def __call__(
cot_reasoning=helper_cot_reasoning,
logprobs=True,
strategy=helper_strategy,
safe_mode=safe_mode,
)
helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
formatted_helper_logprobs: LogprobsForFilterCascade = (
Expand Down Expand Up @@ -302,6 +316,7 @@ def __call__(
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=safe_mode,
)

for idx, large_idx in enumerate(low_conf_idxs):
Expand All @@ -322,6 +337,7 @@ def __call__(
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=safe_mode,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
Expand Down
Loading

0 comments on commit 99f96b2

Please sign in to comment.