diff --git a/examples/op_examples/join_cascade.py b/examples/op_examples/join_cascade.py
index f586738a..f351dccf 100644
--- a/examples/op_examples/join_cascade.py
+++ b/examples/op_examples/join_cascade.py
@@ -18,22 +18,104 @@
 }
 
 skills = [
-    "Math", "Computer Science", "Management", "Creative Writing", "Data Analysis", "Machine Learning",
-    "Project Management", "Problem Solving", "Singing", "Critical Thinking", "Public Speaking", "Teamwork",
-    "Adaptability", "Programming", "Leadership", "Time Management", "Negotiation", "Decision Making", "Networking",
-    "Painting", "Customer Service", "Marketing", "Graphic Design", "Nursery", "SEO", "Content Creation",
-    "Video Editing", "Sales", "Financial Analysis", "Accounting", "Event Planning", "Foreign Languages",
-    "Software Development", "Cybersecurity", "Social Media Management", "Photography", "Writing & Editing",
-    "Technical Support", "Database Management", "Web Development", "Business Strategy", "Operations Management",
-    "UI/UX Design", "Reinforcement Learning", "Data Visualization", "Product Management", "Cloud Computing",
-    "Agile Methodology", "Blockchain", "IT Support", "Legal Research", "Supply Chain Management", "Copywriting",
-    "Human Resources", "Quality Assurance", "Medical Research", "Healthcare Management", "Sports Coaching",
-    "Editing & Proofreading", "Legal Writing", "Human Anatomy", "Chemistry", "Physics", "Biology", "Psychology",
-    "Sociology", "Anthropology", "Political Science", "Public Relations", "Fashion Design", "Interior Design",
-    "Automotive Repair", "Plumbing", "Carpentry", "Electrical Work", "Welding", "Electronics", "Hardware Engineering",
-    "Circuit Design", "Robotics", "Environmental Science", "Marine Biology", "Urban Planning", "Geography",
-    "Agricultural Science", "Animal Care", "Veterinary Science", "Zoology", "Ecology", "Botany", "Landscape Design",
-    "Baking & Pastry", "Culinary Arts", "Bartending", "Nutrition", "Dietary Planning", "Physical Training", "Yoga",
+    "Math",
+    "Computer Science",
+    "Management",
+    "Creative Writing",
+    "Data Analysis",
+    "Machine Learning",
+    "Project Management",
+    "Problem Solving",
+    "Singing",
+    "Critical Thinking",
+    "Public Speaking",
+    "Teamwork",
+    "Adaptability",
+    "Programming",
+    "Leadership",
+    "Time Management",
+    "Negotiation",
+    "Decision Making",
+    "Networking",
+    "Painting",
+    "Customer Service",
+    "Marketing",
+    "Graphic Design",
+    "Nursery",
+    "SEO",
+    "Content Creation",
+    "Video Editing",
+    "Sales",
+    "Financial Analysis",
+    "Accounting",
+    "Event Planning",
+    "Foreign Languages",
+    "Software Development",
+    "Cybersecurity",
+    "Social Media Management",
+    "Photography",
+    "Writing & Editing",
+    "Technical Support",
+    "Database Management",
+    "Web Development",
+    "Business Strategy",
+    "Operations Management",
+    "UI/UX Design",
+    "Reinforcement Learning",
+    "Data Visualization",
+    "Product Management",
+    "Cloud Computing",
+    "Agile Methodology",
+    "Blockchain",
+    "IT Support",
+    "Legal Research",
+    "Supply Chain Management",
+    "Copywriting",
+    "Human Resources",
+    "Quality Assurance",
+    "Medical Research",
+    "Healthcare Management",
+    "Sports Coaching",
+    "Editing & Proofreading",
+    "Legal Writing",
+    "Human Anatomy",
+    "Chemistry",
+    "Physics",
+    "Biology",
+    "Psychology",
+    "Sociology",
+    "Anthropology",
+    "Political Science",
+    "Public Relations",
+    "Fashion Design",
+    "Interior Design",
+    "Automotive Repair",
+    "Plumbing",
+    "Carpentry",
+    "Electrical Work",
+    "Welding",
+    "Electronics",
+    "Hardware Engineering",
+    "Circuit Design",
+    "Robotics",
"Environmental Science", + "Marine Biology", + "Urban Planning", + "Geography", + "Agricultural Science", + "Animal Care", + "Veterinary Science", + "Zoology", + "Ecology", + "Botany", + "Landscape Design", + "Baking & Pastry", + "Culinary Arts", + "Bartending", + "Nutrition", + "Dietary Planning", + "Physical Training", + "Yoga", ] data2 = pd.DataFrame({"Skill": skills}) @@ -42,7 +124,7 @@ df2 = pd.DataFrame(data2) join_instruction = "By taking {Course Name:left} I will learn {Skill:right}" -cascade_args = SemJoinCascadeArgs(recall_target = 0.7, precision_target = 0.7) +cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7) res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True) @@ -51,4 +133,4 @@ print(f" Helper resolved {stats['join_resolved_by_helper_model']} LM calls") print(f"Join cascade used {stats['total_LM_calls']} LM calls in total") print(f"Naive join would require {df1.shape[0]*df2.shape[0]} LM calls") -print(res) \ No newline at end of file +print(res) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 30852ebb..a5fdc570 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,4 +1,5 @@ import hashlib +import logging from typing import Any import litellm @@ -8,11 +9,15 @@ from litellm.utils import token_counter from openai import OpenAIError from tokenizers import Tokenizer +from tqdm import tqdm import lotus from lotus.cache import Cache from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade +logging.getLogger("LiteLLM").setLevel(logging.CRITICAL) +logging.getLogger("httpx").setLevel(logging.CRITICAL) + class LM: def __init__( @@ -36,7 +41,9 @@ def __init__( self.stats: LMStats = LMStats() self.cache = Cache(max_cache_size) - def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput: + def __call__( + self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any] + ) -> LMOutput: all_kwargs = {**self.kwargs, **kwargs} # Set top_logprobs if logprobs requested @@ -70,7 +77,7 @@ def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any def _process_uncached_messages(self, uncached_data, all_kwargs): """Processes uncached messages in batches and returns responses.""" uncached_responses = [] - for i in range(0, len(uncached_data), self.max_batch_size): + for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"): batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]] uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs)) return uncached_responses diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 56a95ff9..b8c9a278 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -12,6 +12,7 @@ def sem_agg( model: lotus.models.LM, user_instruction: str, partition_ids: list[int], + safe_mode: bool = False, ) -> SemanticAggOutput: """ Aggregates multiple documents into a single answer using a model. 
@@ -60,6 +61,10 @@ def node_doc_formatter(doc: str, ctr: int) -> str:
     def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
         return leaf_doc_formatter(doc, ctr) if tree_level == 0 else node_doc_formatter(doc, ctr)
 
+    if safe_mode:
+        # TODO: implement safe mode
+        lotus.logger.warning("Safe mode is not implemented yet")
+
     tree_level = 0
     summaries: list[str] = []
     new_partition_ids: list[int] = []
@@ -76,6 +81,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
         template_tokens = model.count_tokens(template)
         context_tokens = 0
         doc_ctr = 1  # num docs in current prompt
+
         for idx in range(len(docs)):
             partition_id = partition_ids[idx]
             formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
@@ -108,6 +114,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
             lotus.logger.debug(f"Prompt added to batch: {prompt}")
             batch.append([{"role": "user", "content": prompt}])
             new_partition_ids.append(cur_partition_id)
+
         lm_output: LMOutput = model(batch)
         summaries = lm_output.outputs
 
@@ -117,6 +124,8 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
         docs = summaries
         lotus.logger.debug(f"Model outputs from tree level {tree_level}: {summaries}")
         tree_level += 1
+        if safe_mode:
+            model.print_total_usage()
 
     return SemanticAggOutput(outputs=summaries)
 
@@ -139,6 +148,7 @@ def __call__(
         all_cols: bool = False,
         suffix: str = "_output",
         group_by: list[str] | None = None,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Applies semantic aggregation over a dataframe.
@@ -189,6 +199,7 @@ def __call__(
             lotus.settings.lm,
             formatted_usr_instr,
             partition_ids,
+            safe_mode=safe_mode,
         )
 
         # package answer in a dataframe
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index 515e56cb..93c4c9ba 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -6,6 +6,7 @@
 from lotus.models import LM
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
+from lotus.utils import show_safe_mode
 
 from .postprocessors import extract_postprocess
 
@@ -16,6 +17,7 @@ def sem_extract(
     output_cols: dict[str, str | None],
     extract_quotes: bool = False,
     postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
+    safe_mode: bool = False,
 ) -> SemanticExtractOutput:
     """
     Extracts attributes and values from a list of documents using a model.
@@ -39,6 +41,12 @@ def sem_extract(
         lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
         inputs.append(prompt)
 
+    # check if safe_mode is enabled
+    if safe_mode:
+        estimated_cost = sum(model.count_tokens(input) for input in inputs)
+        estimated_LM_calls = len(docs)
+        show_safe_mode(estimated_cost, estimated_LM_calls)
+
     # call model
     lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})
 
@@ -46,6 +54,8 @@ def sem_extract(
     postprocess_output = postprocessor(lm_output.outputs)
     lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
     lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
+    if safe_mode:
+        model.print_total_usage()
 
     return SemanticExtractOutput(**postprocess_output.model_dump())
 
@@ -68,6 +78,7 @@ def __call__(
         extract_quotes: bool = False,
         postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
         return_raw_outputs: bool = False,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Extracts the attributes and values of a dataframe.
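# --- Editor's illustrative sketch (not part of the patch above): what enabling
# safe_mode on sem_extract is expected to look like from user code. The model
# name, dataframe contents, and column names are assumptions for the example,
# not values taken from the patch.
import pandas as pd
import lotus
from lotus.models import LM

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

df = pd.DataFrame({"listing": ["2-bed apartment, 850 sqft, $2,100/month"]})
# With safe_mode=True, sem_extract first prints the estimated token cost and
# LM call count via show_safe_mode, then counts down before calling the model.
out = df.sem_extract(
    input_cols=["listing"],
    output_cols={"rent": "monthly rent in dollars"},
    safe_mode=True,
)
print(out)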
@@ -96,6 +107,7 @@ def __call__(
             output_cols=output_cols,
             extract_quotes=extract_quotes,
             postprocessor=postprocessor,
+            safe_mode=safe_mode,
         )
 
         new_df = self._obj.copy()
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index bad0d5e3..372df5dc 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -7,6 +7,7 @@
 import lotus
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
+from lotus.utils import show_safe_mode
 
 from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
 from .postprocessors import filter_postprocess
@@ -22,6 +23,7 @@ def sem_filter(
     cot_reasoning: list[str] | None = None,
     strategy: str | None = None,
     logprobs: bool = False,
+    safe_mode: bool = False,
 ) -> SemanticFilterOutput:
     """
     Filters a list of documents based on a given user instruction using a language model.
@@ -47,6 +49,12 @@ def sem_filter(
         lotus.logger.debug(f"input to model: {prompt}")
         inputs.append(prompt)
     kwargs: dict[str, Any] = {"logprobs": logprobs}
+
+    if safe_mode:
+        estimated_total_calls = len(docs)
+        estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
+        show_safe_mode(estimated_total_cost, estimated_total_calls)
+
     lm_output: LMOutput = model(inputs, **kwargs)
 
     postprocess_output = filter_postprocess(
@@ -56,6 +64,9 @@ def sem_filter(
     )
     lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
    lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
 
+    if safe_mode:
+        model.print_total_usage()
+
     return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None)
 
@@ -88,6 +99,7 @@ def learn_filter_cascade_thresholds(
         examples_answers=examples_answers,
         cot_reasoning=cot_reasoning,
         strategy=strategy,
+        safe_mode=False,
     ).outputs
 
     best_combination, _ = learn_cascade_thresholds(
@@ -137,6 +149,7 @@ def __call__(
         precision_target: float | None = None,
         failure_probability: float | None = None,
         return_stats: bool = False,
+        safe_mode: bool = False,
     ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
         """
         Applies semantic filter over a dataframe.
@@ -221,6 +234,7 @@ def __call__(
                 cot_reasoning=helper_cot_reasoning,
                 logprobs=True,
                 strategy=helper_strategy,
+                safe_mode=safe_mode,
             )
             helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
             formatted_helper_logprobs: LogprobsForFilterCascade = (
@@ -302,6 +316,7 @@ def __call__(
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
 
             for idx, large_idx in enumerate(low_conf_idxs):
@@ -322,6 +337,7 @@ def __call__(
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
             outputs = output.outputs
             raw_outputs = output.raw_outputs
diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py
index 5248998f..a9aae5f7 100644
--- a/lotus/sem_ops/sem_join.py
+++ b/lotus/sem_ops/sem_join.py
@@ -5,6 +5,7 @@
 import lotus
 from lotus.templates import task_instructions
 from lotus.types import SemanticJoinOutput, SemJoinCascadeArgs
+from lotus.utils import show_safe_mode
 
 from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds
 from .sem_filter import sem_filter
@@ -24,6 +25,7 @@ def sem_join(
     cot_reasoning: list[str] | None = None,
     default: bool = True,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticJoinOutput:
     """
     Joins two series using a model.
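# --- Editor's illustrative sketch (not part of the patch above): the estimate
# sem_filter makes under safe_mode, written standalone. count_tokens is a rough
# whitespace stand-in for LM.count_tokens; only input tokens are counted, so
# the printed cost understates the true cost by the output tokens.
def count_tokens(prompt: str) -> int:
    return len(prompt.split())

docs = ["Course: Riemannian Geometry", "Course: Intro to Cooking"]
prompts = [f"Claim: taking '{d}' requires math. True or False?" for d in docs]

estimated_total_calls = len(docs)  # one LM call per document
estimated_total_cost = sum(count_tokens(p) for p in prompts)
print(f"Estimated cost: {estimated_total_cost} tokens")
print(f"Estimated LM calls: {estimated_total_calls}")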
@@ -53,6 +55,19 @@ def sem_join(
     left_multimodal_data = task_instructions.df2multimodal_info(l1.to_frame(col1_label), [col1_label])
     right_multimodal_data = task_instructions.df2multimodal_info(l2.to_frame(col2_label), [col2_label])
 
+
+    if safe_mode:
+        sample_docs = task_instructions.merge_multimodal_info([left_multimodal_data[0]], right_multimodal_data)
+        estimated_tokens_per_call = model.count_tokens(
+            lotus.templates.task_instructions.filter_formatter(
+                sample_docs[0], user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy
+            )
+        )
+        estimated_total_calls = len(l1) * len(l2)
+        estimated_total_cost = estimated_tokens_per_call * estimated_total_calls
+        print("Sem_Join:")
+        show_safe_mode(estimated_total_cost, estimated_total_calls)
+
     # for i1 in enumerate(l1):
     for id1, i1 in zip(ids1, left_multimodal_data):
         # perform llm filter
@@ -113,6 +128,7 @@ def sem_join_cascade(
     cot_reasoning: list[str] | None = None,
     default: bool = True,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticJoinOutput:
     """
     Joins two series using a cascade helper model and a large model.
@@ -182,6 +198,10 @@ def sem_join_cascade(
     num_helper = len(helper_high_conf)
     num_large = len(helper_low_conf)
 
+    if safe_mode:
+        # TODO: implement safe mode
+        lotus.logger.warning("Safe mode is not implemented yet.")
+
     # Accept helper results with high confidence
     join_results = [(row["_left_id"], row["_right_id"], None) for _, row in helper_high_conf.iterrows()]
 
@@ -538,6 +558,7 @@ def __call__(
         default: bool = True,
         cascade_args: SemJoinCascadeArgs | None = None,
         return_stats: bool = False,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Applies semantic join over a dataframe.
@@ -647,6 +668,7 @@ def __call__(
                 cot_reasoning=cot_reasoning,
                 default=default,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
         else:
             output = sem_join(
@@ -663,6 +685,7 @@ def __call__(
                 cot_reasoning=cot_reasoning,
                 default=default,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
         join_results = output.join_results
         all_raw_outputs = output.all_raw_outputs
diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py
index 99d8a84d..0526e221 100644
--- a/lotus/sem_ops/sem_map.py
+++ b/lotus/sem_ops/sem_map.py
@@ -5,6 +5,7 @@
 import lotus
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput
+from lotus.utils import show_safe_mode
 
 from .postprocessors import map_postprocess
 
@@ -18,6 +19,7 @@ def sem_map(
     examples_answers: list[str] | None = None,
     cot_reasoning: list[str] | None = None,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticMapOutput:
     """
     Maps a list of documents to a list of outputs using a model.
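# --- Editor's illustrative sketch (not part of the patch above): why
# sem_join's safe_mode estimate is quadratic. A plain sem_join issues one
# filter call per (left, right) pair, so calls scale as len(l1) * len(l2);
# the cascade in join_cascade.py exists to cut that down. The token count
# per call below is an assumed figure, not one computed by the patch.
left = ["Math Lectures", "Cooking 101"]        # stand-in for l1
right = ["Math", "Culinary Arts", "Baking"]    # stand-in for l2

estimated_tokens_per_call = 120                     # assumed prompt size
estimated_total_calls = len(left) * len(right)      # 2 * 3 = 6 filter calls
estimated_total_cost = estimated_tokens_per_call * estimated_total_calls
print(f"Estimated cost: {estimated_total_cost} tokens")
print(f"Estimated LM calls: {estimated_total_calls}")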
@@ -44,6 +46,12 @@ def sem_map(
         lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
         inputs.append(prompt)
 
+    # check if safe_mode is enabled
+    if safe_mode:
+        estimated_cost = sum(model.count_tokens(input) for input in inputs)
+        estimated_LM_calls = len(docs)
+        show_safe_mode(estimated_cost, estimated_LM_calls)
+
     # call model
     lm_output: LMOutput = model(inputs)
 
@@ -52,6 +60,8 @@ def sem_map(
     lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
     lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
     lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
+    if safe_mode:
+        model.print_total_usage()
 
     return SemanticMapOutput(**postprocess_output.model_dump())
 
@@ -78,6 +88,7 @@ def __call__(
         suffix: str = "_map",
         examples: pd.DataFrame | None = None,
         strategy: str | None = None,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Applies semantic map over a dataframe.
@@ -125,6 +136,7 @@ def __call__(
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
             strategy=strategy,
+            safe_mode=safe_mode,
         )
 
         new_df = self._obj.copy()
diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py
index 0af1b475..461e06a1 100644
--- a/lotus/sem_ops/sem_topk.py
+++ b/lotus/sem_ops/sem_topk.py
@@ -8,6 +8,7 @@
 import lotus
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticTopKOutput
+from lotus.utils import show_safe_mode
 
 
 def get_match_prompt_binary(
@@ -121,6 +122,7 @@ def llm_naive_sort(
     docs: list[dict[str, Any]],
     user_instruction: str,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
     Sorts the documents using a naive quadratic method.
@@ -140,6 +142,8 @@ def llm_naive_sort(
 
     llm_calls = len(pairs)
     comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
+    if safe_mode:
+        show_safe_mode(tokens, llm_calls)
     votes = [0] * N
     idx = 0
     for i in range(N):
@@ -163,6 +167,7 @@ def llm_quicksort(
     embedding: bool = False,
     strategy: str | None = None,
     cascade_threshold: float | None = None,
+    safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
     Sorts the documents using quicksort.
@@ -180,6 +185,13 @@ def llm_quicksort(
     stats = {}
     stats["total_tokens"] = 0
     stats["total_llm_calls"] = 0
+    if safe_mode:
+        sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy)
+        estimated_quickselect_calls = 2 * K
+        estimated_quicksort_calls = 2 * len(docs) * np.log(len(docs))
+        estimated_total_calls = estimated_quickselect_calls + estimated_quicksort_calls
+        estimated_total_tokens = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
+        show_safe_mode(estimated_total_tokens, estimated_total_calls)
 
     if cascade_threshold is not None:
         stats["total_small_tokens"] = 0
@@ -275,6 +287,7 @@ def llm_heapsort(
     user_instruction: str,
     K: int,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
     Sorts the documents using a heap.
@@ -287,11 +300,21 @@ def llm_heapsort(
     Returns:
         SemanticTopKOutput: The indexes of the top k documents and stats.
""" + + if safe_mode: + sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy) + estimated_heap_construction_calls = len(docs) * np.log(len(docs)) + estimated_top_k_extraction_calls = K * np.log(len(docs)) + estimated_total_calls = estimated_heap_construction_calls + estimated_top_k_extraction_calls + estimated_total_cost = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls + show_safe_mode(estimated_total_cost, estimated_total_calls) + HeapDoc.num_calls = 0 HeapDoc.total_tokens = 0 HeapDoc.strategy = strategy N = len(docs) heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)] + heap = heapq.nsmallest(K, heap) indexes = [heapq.heappop(heap).idx for _ in range(len(heap))] @@ -320,6 +343,7 @@ def __call__( group_by: list[str] | None = None, cascade_threshold: float | None = None, return_stats: bool = False, + safe_mode: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Sorts the DataFrame based on the user instruction and returns the top K rows. @@ -392,14 +416,22 @@ def __call__( embedding=method == "quick-sem", strategy=strategy, cascade_threshold=cascade_threshold, + safe_mode=safe_mode, ) elif method == "heap": - output = llm_heapsort(multimodal_data, formatted_usr_instr, K, strategy=strategy) + output = llm_heapsort( + multimodal_data, + formatted_usr_instr, + K, + strategy=strategy, + safe_mode=safe_mode, + ) elif method == "naive": output = llm_naive_sort( multimodal_data, formatted_usr_instr, strategy=strategy, + safe_mode=safe_mode, ) else: raise ValueError(f"Method {method} not recognized") diff --git a/lotus/utils.py b/lotus/utils.py index 1f86347c..fa461e1b 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -1,4 +1,5 @@ import base64 +import time from io import BytesIO from typing import Callable @@ -106,3 +107,17 @@ def fetch_image(image: str | np.ndarray | Image.Image | None, image_type: str = return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") return image_obj + + +def show_safe_mode(estimated_cost, estimated_LM_calls): + print(f"Estimated cost: {estimated_cost} tokens") + print(f"Estimated LM calls: {estimated_LM_calls}") + try: + for i in range(5, 0, -1): + print(f"Proceeding execution in {i} seconds... Press CTRL+C to cancel", end="\r") + time.sleep(1) + print(" " * 60, end="\r") + print("\n") + except KeyboardInterrupt: + print("\nExecution cancelled by user") + exit(0)