From d2b8cba6dd15e4fd69848e66b416c1bfaded255f Mon Sep 17 00:00:00 2001 From: Parth Asawa Date: Mon, 28 Oct 2024 10:50:58 -0700 Subject: [PATCH 1/3] Learned cascade thresholds --- .github/tests/lm_tests.py | 76 +++++++++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 14 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 306747cd..af23d1c7 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -21,12 +21,12 @@ def test_filter_operation(setup_models): lotus.settings.configure(lm=gpt_4o_mini) # Test filter operation on an easy dataframe - data = {"Text": ["I am really exicted to go to class today!", "I am very sad"]} + data = {"Text": ["I am really excited to go to class today!", "I am very sad"]} df = pd.DataFrame(data) user_instruction = "{Text} is a positive sentiment" filtered_df = df.sem_filter(user_instruction) - expected_df = pd.DataFrame({"Text": ["I am really exicted to go to class today!"]}) + expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]}) assert filtered_df.equals(expected_df) @@ -34,22 +34,70 @@ def test_filter_cascade(setup_models): gpt_4o_mini, gpt_4o = setup_models lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) - data = {"Text": ["I am really exicted to go to class today!", "I am very sad"]} + data = { + "Text": [ + # Positive examples + "I am really excited to go to class today!", + "Today is going to be an amazing day!", + "I absolutely love the new project I am working on.", + "Feeling so grateful for everything I have.", + "I can't wait to see my friends this weekend!", + "The weather is beautiful, and I feel fantastic.", + "Just received some great news about my promotion!", + "I'm so happy to have such supportive colleagues.", + "I'm thrilled to be learning something new every day.", + "Life is really good right now, and I feel blessed.", + "I am proud of all the progress I've made this year.", + "Today was productive, and I feel accomplished.", + "I’m really enjoying my workout routine lately!", + "Got a compliment from my manager today, feeling awesome!", + "Looking forward to spending time with family tonight.", + "Just finished a great book and feel inspired!", + "Had a lovely meal with friends, life is good!", + "Everything is going as planned, couldn't be happier.", + "Feeling super motivated and ready to take on challenges!", + "I appreciate all the small things that bring me joy.", + + # Negative examples + "I am very sad.", + "Today has been really tough; I feel exhausted.", + "I'm feeling pretty down about how things are going.", + "I’m overwhelmed with all these challenges.", + "It’s hard to stay positive when things keep going wrong.", + "I feel so alone and unappreciated.", + "My energy is low, and nothing seems to cheer me up.", + "Feeling anxious about everything lately.", + "I’m disappointed with the way my project turned out.", + "Today has been one of those days where everything goes wrong.", + "Life feels really overwhelming right now.", + "I can't seem to find any motivation these days.", + "I’m worried about the future and what it holds.", + "It's been a stressful day, and I feel mentally drained.", + "I feel like I'm falling behind everyone else.", + "Just can't seem to catch a break recently.", + "I’m really struggling to keep up with all my responsibilities.", + "Had an argument with a close friend, feeling hurt.", + "I don’t feel supported by my team at work.", + "Life has been tough lately, and I’m feeling down.", + ] + } + df = pd.DataFrame(data) user_instruction = "{Text} is a positive sentiment" # All filters resolved by the helper model - filtered_df, stats = df.sem_filter(user_instruction, cascade_threshold=0, return_stats=True) - assert stats["filters_resolved_by_large_model"] == 0, stats - assert stats["filters_resolved_by_helper_model"] == 2, stats - expected_df = pd.DataFrame({"Text": ["I am really exicted to go to class today!"]}) - assert filtered_df.equals(expected_df) - - # All filters resolved by the large model - filtered_df, stats = df.sem_filter(user_instruction, cascade_threshold=1.01, return_stats=True) - assert stats["filters_resolved_by_large_model"] == 2, stats - assert stats["filters_resolved_by_helper_model"] == 0, stats - assert filtered_df.equals(expected_df) + filtered_df, stats = df.sem_filter( + user_instruction=user_instruction, + learn_cascade_threshold_sample_percentage=0.5, + recall_target=0.9, + precision_target=0.9, + failure_probability=0.2, + return_stats=True, + ) + + assert "I am really excited to go to class today!" in filtered_df["Text"].values + assert "I am very sad" not in filtered_df["Text"].values + assert stats["filters_resolved_by_helper_model"] > 0, stats def test_top_k(setup_models): From 3c2b3fdca2f7ea9dc4df32359beaafcc95b2662c Mon Sep 17 00:00:00 2001 From: Parth Asawa Date: Mon, 28 Oct 2024 10:51:27 -0700 Subject: [PATCH 2/3] Core code --- examples/op_examples/filter_cascade.py | 108 ++++++++++++++++++++- lotus/models/openai_model.py | 39 +++++++- lotus/sem_ops/cascade_utils.py | 110 +++++++++++++++++++++ lotus/sem_ops/sem_filter.py | 126 +++++++++++++++++++++---- 4 files changed, 364 insertions(+), 19 deletions(-) create mode 100644 lotus/sem_ops/cascade_utils.py diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index f58a1303..5af900b2 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -13,10 +13,116 @@ "Optimization Methods in Engineering", "Digital Design and Integrated Circuits", "Computer Security", + "Data Structures and Algorithms", + "Machine Learning", + "Artificial Intelligence", + "Natural Language Processing", + "Introduction to Robotics", + "Control Systems", + "Linear Algebra and Differential Equations", + "Database Systems", + "Cloud Computing", + "Software Engineering", + "Operating Systems", + "Discrete Mathematics", + "Numerical Methods", + "Wireless Communication Systems", + "Embedded Systems", + "Advanced Computer Architecture", + "Graph Theory", + "Cryptography and Network Security", + "Big Data Analytics", + "Deep Learning", + "Organic Chemistry", + "Molecular Biology", + "Environmental Science", + "Genetics and Evolution", + "Human Physiology", + "Introduction to Anthropology", + "Cultural Studies", + "Political Theory", + "Macroeconomics", + "Microeconomics", + "Introduction to Sociology", + "Developmental Psychology", + "Cognitive Science", + "Introduction to Philosophy", + "Ethics and Moral Philosophy", + "History of Western Civilization", + "Art History: Renaissance to Modern", + "World Literature", + "Introduction to Journalism", + "Public Speaking and Communication", + "Creative Writing", + "Music Theory", + "Introduction to Theater", + "Film Studies", + "Environmental Policy and Law", + "Sustainability and Renewable Energy", + "Urban Planning and Design", + "International Relations", + "Marketing Principles", + "Organizational Behavior", + "Financial Accounting", + "Corporate Finance", + "Business Law", + "Supply Chain Management", + "Operations Research", + "Entrepreneurship and Innovation", + "Introduction to Psychology", + "Health Economics", + "Biostatistics", + "Social Work Practice", + "Public Health Policy", + "Environmental Ethics", + "History of Political Thought", + "Quantitative Research Methods", + "Comparative Politics", + "Urban Economics", + "Behavioral Economics", + "Sociology of Education", + "Social Psychology", + "Gender Studies", + "Media and Communication Studies", + "Advertising and Brand Strategy", + "Sports Management", + "Introduction to Archaeology", + "Ecology and Conservation Biology", + "Marine Biology", + "Geology and Earth Science", + "Astronomy and Astrophysics", + "Introduction to Meteorology", + "Introduction to Oceanography", + "Quantum Physics", + "Thermodynamics", + "Fluid Mechanics", + "Solid State Physics", + "Classical Mechanics", + "Introduction to Civil Engineering", + "Material Science and Engineering", + "Structural Engineering", + "Environmental Engineering", + "Energy Systems Engineering", + "Aerodynamics", + "Heat Transfer", + "Renewable Energy Systems", + "Transportation Engineering", + "Water Resources Management", + "Principles of Accounting", + "Project Management", + "International Business", + "Business Analytics", ] } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df, stats = df.sem_filter(user_instruction, cascade_threshold=0.95, return_stats=True) +df, stats = df.sem_filter( + user_instruction=user_instruction, + learn_cascade_threshold_sample_percentage=0.5, + recall_target=0.9, + precision_target=0.9, + failure_probability=0.2, + return_stats=True, +) print(df) print(stats) diff --git a/lotus/models/openai_model.py b/lotus/models/openai_model.py index f950b97f..57fb20eb 100644 --- a/lotus/models/openai_model.py +++ b/lotus/models/openai_model.py @@ -88,7 +88,7 @@ def handle_chat_request( then a list of logprobs is also returned. """ if kwargs.get("logprobs", False): - kwargs["top_logprobs"] = 1 + kwargs["top_logprobs"] = 10 kwargs = {**self.kwargs, **kwargs} kwargs["messages"] = messages @@ -127,6 +127,8 @@ def handle_completion_request( kwargs = {**self.kwargs, **kwargs} kwargs["prompt"] = prompt + if kwargs.get("logprobs", False): + kwargs["logprobs"] = 10 response = self.completion_request(**kwargs) choices = response["choices"] @@ -254,6 +256,41 @@ def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], all_confidences.append(confidences) return all_tokens, all_confidences + + def format_logprobs_for_filter_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: + all_tokens = [] + all_confidences = [] + all_true_probs = [] + for idx in range(len(logprobs)): + if self.provider == "vllm": + tokens = logprobs[idx]["tokens"] + confidences = np.exp(logprobs[idx]["token_logprobs"]) + top_logprobs = logprobs[idx]["top_logprobs"][0] + if 'True' in top_logprobs and 'False' in top_logprobs: + true_prob = np.exp(top_logprobs['True']) + false_prob = np.exp(top_logprobs['False']) + all_true_probs.append(true_prob / (true_prob + false_prob)) + else: + all_true_probs.append(1 if 'True' in top_logprobs else 0) + + elif self.provider == "openai": + content = logprobs[idx]["content"] + tokens = [content[t_idx]["token"] for t_idx in range(len(content))] + confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) + top_logprobs = {x["token"]:x["logprob"] for x in content[0]["top_logprobs"]} + + true_prob, false_prob = 0, 0 + if top_logprobs and 'True' in top_logprobs and 'False' in top_logprobs: + true_prob = np.exp(top_logprobs['True']) + false_prob = np.exp(top_logprobs['False']) + all_true_probs.append(true_prob / (true_prob + false_prob)) + else: + all_true_probs.append(1 if 'True' in top_logprobs else 0) + + all_tokens.append(tokens) + all_confidences.append(confidences) + + return all_tokens, all_confidences, all_true_probs def chat_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: """Send chat request to OpenAI server. diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py new file mode 100644 index 00000000..7e736a44 --- /dev/null +++ b/lotus/sem_ops/cascade_utils.py @@ -0,0 +1,110 @@ +import numpy as np + +import lotus + + +def importance_sampling( + proxy_scores: list[float], + sample_percentage: float, +): + """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" + + w = np.sqrt(proxy_scores) + w = 0.5 * w / np.sum(w) + 0.5 * np.ones((len(proxy_scores))) / len(proxy_scores) + indices = np.arange(len(proxy_scores)) + sample_size = (int) (sample_percentage * len(proxy_scores)) + sample_indices = np.random.choice(indices, sample_size, p=w) + correction_factors = (1/len(proxy_scores)) / w + + return sample_indices, correction_factors + +def calibrate_llm_logprobs(true_probs: list[float]): + """Transforms true probabilities to calibrate LLM proxies.""" + num_quantiles = 50 + quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) + true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) + true_probs = np.clip(true_probs, 0, 1) + return true_probs + +def learn_cascade_thresholds( + proxy_scores: list[float], + oracle_outputs: list[float], + sample_correction_factors: list[float], + recall_target: float, + precision_target: float, + delta: float +): + """Learns cascade thresholds given targets and proxy scores, + oracle outputs over the sample, and correction factors for the + sample.""" + + def UB(mean, std_dev, s, delta): + return mean + (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + + def LB(mean, std_dev, s, delta): + return mean - (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + + def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: + helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] + sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold] + total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs) + recall = (sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)) / total_correct + return recall + + def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: + helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] + sent_to_oracle = [x for x in sorted_pairs if pos_threshold > x[0] > neg_threshold] + oracle_positive = sum(x[1] for x in sent_to_oracle) + true_positives = sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + oracle_positive + predicted_positives = sum(1 for x in helper_accepted if x[0] >= pos_threshold) + oracle_positive + precision = true_positives / predicted_positives if predicted_positives > 0 else 0 + return precision + + # Pair helper model probabilities with helper correctness and oracle answer + paired_data = list(zip(proxy_scores, oracle_outputs, sample_correction_factors)) + sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True) + sample_size = len(sorted_pairs) + + best_combination = (1,0) # initial tau_+, tau_- + + # Find tau_negative based on recall + tau_neg_0 = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target) + best_combination = (best_combination[0], tau_neg_0) + + # Do a statistical correction to get a new target recall + Z1 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] >= best_combination[1]] + Z2 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] < best_combination[1]] + + mean_z1 = np.mean(Z1) if Z1 else 0 + std_z1 = np.std(Z1) if Z1 else 0 + mean_z2 = np.mean(Z2) if Z2 else 0 + std_z2 = np.std(Z2) if Z2 else 0 + + corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta/2)/(UB(mean_z1, std_z1, sample_size, delta/2) + LB(mean_z2, std_z2, sample_size, delta/2)) + corrected_recall_target = min(1, corrected_recall_target) + tau_neg_prime = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target) + best_combination = (best_combination[0], tau_neg_prime) + + # Do a statistical correction to get a target satisfying precision + candidate_thresholds = [1] + for pair in sorted_pairs: + possible_threshold = pair[0] + Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold] + mean_z = np.mean(Z) if Z else 0 + std_z = np.std(Z) if Z else 0 + p_l = LB(mean_z, std_z, len(Z), delta/len(sorted_pairs)) + if p_l > precision_target: + candidate_thresholds.append(possible_threshold) + + best_combination = (max(best_combination[1], min(candidate_thresholds)), best_combination[1]) + oracle_calls = sum(1 for x in proxy_scores if best_combination[0] > x > best_combination[1]) + + no_correction_sorted_pairs = [tup[:2] + (1,) for tup in sorted_pairs] + lotus.logger.info(f"Sample recall: {recall(best_combination[0], best_combination[1], no_correction_sorted_pairs)}") + lotus.logger.info(f"Sample precision: {precision(best_combination[0], best_combination[1], sorted_pairs)}") + + return best_combination, oracle_calls + +def calibrate_sem_sim_join(true_score: list[float]): + true_score = np.clip(true_score, 0, 1) + return true_score \ No newline at end of file diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index ed81d958..aafd6aa6 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -6,6 +6,7 @@ from lotus.templates import task_instructions from lotus.types import SemanticFilterOutput +from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds from .postprocessors import filter_postprocess @@ -59,6 +60,52 @@ def sem_filter( return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=raw_logprobs if logprobs else None) +def learn_filter_cascade_thresholds( + sample_df_txt: str, + lm: lotus.models.LM, + formatted_usr_instr: str, + default: bool, + recall_target: float, + precision_target: float, + delta: float, + helper_true_probs: list[float], + sample_correction_factors: list[float], + examples_df_txt: str | None = None, + examples_answers: str | None = None, + cot_reasoning: list | None = None, + strategy: str | None = None, +) -> tuple[float, float]: + """Automatically learns the cascade thresholds for a cascade + filter given a sample of data and doing a search across threshold + to see what threshold gives the best accuracy.""" + + try: + large_outputs = sem_filter( + sample_df_txt, + lm, + formatted_usr_instr, + default=default, + examples_df_txt=examples_df_txt, + examples_answers=examples_answers, + cot_reasoning=cot_reasoning, + strategy=strategy, + ).outputs + + best_combination, _ = learn_cascade_thresholds( + proxy_scores=helper_true_probs, + oracle_outputs=large_outputs, + sample_correction_factors=sample_correction_factors, + recall_target=recall_target, + precision_target=precision_target, + delta=delta + ) + + lotus.logger.info(f"Learned cascade thresholds: {best_combination}") + return best_combination + + except Exception as e: + lotus.logger.error(f"Error while learning filter cascade thresholds: {e}") + return None @pd.api.extensions.register_dataframe_accessor("sem_filter") class SemFilterDataframe: @@ -85,7 +132,10 @@ def __call__( helper_examples: pd.DataFrame | None = None, strategy: str | None = None, helper_strategy: str | None = None, - cascade_threshold: float | None = None, + learn_cascade_threshold_sample_percentage: int | None = None, + recall_target: float | None = None, + precision_target: float | None = None, + failure_probability: float | None = None, return_stats: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ @@ -100,7 +150,10 @@ def __call__( helper_examples (pd.DataFrame | None): The helper examples dataframe. Defaults to None. strategy (str | None): The reasoning strategy. Defaults to None. helper_strategy (str | None): The reasoning strategy for helper. Defaults to None. - cascade_threshold (float | None): The threshold for cascading. Defaults to None. + learn_cascade_threshold_sample_size (Optional[int]): The percentage of samples from which to learn thresholds when cascading. + recall_target (float | None): The specified recall target. + precision_target (float | None): The specified precision target. + failure_probability (float | None): The specified failure probability for precision/recall targets. return_stats (bool): Whether to return statistics. Defaults to False. Returns: @@ -132,10 +185,8 @@ def __call__( return_explanations = True cot_reasoning = examples["Reasoning"].tolist() - if cascade_threshold is not None: - stats["filters_resolved_by_helper_model"] = 0 - stats["filters_resolved_by_large_model"] = 0 - + pos_cascade_threshold, neg_cascade_threshold = None, None + if learn_cascade_threshold_sample_percentage is not None: # Get few-shot examples for small LM helper_examples_df_txt = None helper_examples_answers = None @@ -147,6 +198,15 @@ def __call__( if helper_strategy == "cot": helper_cot_reasoning = examples["Reasoning"].tolist() + + if learn_cascade_threshold_sample_percentage and lotus.settings.helper_lm: + if helper_strategy == "cot": + lotus.logger.error("CoT not supported for helper models in cascades.") + raise Exception + + if recall_target is None or precision_target is None or failure_probability is None: + lotus.logger.error("Recall target, precision target, and confidence need to be specified for learned thresholds.") + raise Exception # Run small LM and get logits helper_output = sem_filter( @@ -160,19 +220,51 @@ def __call__( logprobs=True, strategy=helper_strategy, ) + helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs + _, _, helper_true_probs = lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) - high_conf_idxs = set() - helper_tokens, helper_confidences = lotus.settings.helper_lm.format_logprobs_for_cascade( - helper_output.logprobs + helper_true_probs = calibrate_llm_logprobs(helper_true_probs) + + sample_indices, correction_factors = importance_sampling(helper_true_probs, learn_cascade_threshold_sample_percentage) + sample_df = self._obj.loc[sample_indices] + sample_df_txt = task_instructions.df2text(sample_df, col_li) + sample_helper_true_probs = [helper_true_probs[i] for i in sample_indices] + sample_correction_factors = correction_factors[sample_indices] + + pos_cascade_threshold, neg_cascade_threshold = learn_filter_cascade_thresholds( + sample_df_txt=sample_df_txt, + lm=lotus.settings.lm, + formatted_usr_instr=formatted_usr_instr, + default=default, + recall_target=recall_target, + precision_target=precision_target, + delta=failure_probability/2, + helper_true_probs=sample_helper_true_probs, + sample_correction_factors=sample_correction_factors, + examples_df_txt=examples_df_txt, + examples_answers=examples_answers, + cot_reasoning=cot_reasoning, + strategy=strategy, ) + stats["pos_cascade_threshold"] = pos_cascade_threshold + stats["neg_cascade_threshold"] = neg_cascade_threshold + + if pos_cascade_threshold is not None and neg_cascade_threshold is not None: + stats["filters_resolved_by_helper_model"] = 0 + stats["filters_resolved_by_large_model"] = 0 + + high_conf_idxs = set() + # Find where true/false is said and look at confidence - for idx_i, (tokens, confidences) in enumerate(zip(helper_tokens, helper_confidences)): - for idx_j in range(len(tokens) - 1, -1, -1): - if tokens[idx_j].strip(" \n").lower() in ["true", "false"]: - conf = confidences[idx_j] - if conf >= cascade_threshold: - high_conf_idxs.add(idx_i) + for idx_i in range(len(helper_true_probs)): + true_prob = helper_true_probs[idx_i] + if true_prob >= pos_cascade_threshold or true_prob <= neg_cascade_threshold: + high_conf_idxs.add(idx_i) + helper_outputs[idx_i] = True if true_prob >= pos_cascade_threshold else False if true_prob <= neg_cascade_threshold else helper_outputs[idx_i] + + lotus.logger.info(f"Num routed to smaller model: {len(high_conf_idxs)}") + stats["num_routed_to_helper_model"] = len(high_conf_idxs) outputs: list[bool] = [False] * len(df_txt) raw_outputs: list[str] = [""] * len(df_txt) @@ -182,12 +274,12 @@ def __call__( x is None for x in helper_output.explanations ) for idx in high_conf_idxs: - outputs[idx] = helper_output.outputs[idx] + outputs[idx] = helper_outputs[idx] raw_outputs[idx] = helper_output.raw_outputs[idx] explanations[idx] = helper_output.explanations[idx] # Send low confidence samples to large LM if any - low_conf_idxs = sorted([i for i in range(len(helper_output.outputs)) if i not in high_conf_idxs]) + low_conf_idxs = sorted([i for i in range(len(helper_outputs)) if i not in high_conf_idxs]) low_conf_df_txt = [df_txt[idx] for idx in low_conf_idxs] if low_conf_idxs: large_output = sem_filter( From 231a4469936f09ba308995704a35d1e2f92a73b8 Mon Sep 17 00:00:00 2001 From: Parth Asawa Date: Mon, 28 Oct 2024 20:37:18 -0700 Subject: [PATCH 3/3] Add return types to address Sid's comment --- lotus/sem_ops/cascade_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 7e736a44..3302a493 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -6,7 +6,7 @@ def importance_sampling( proxy_scores: list[float], sample_percentage: float, -): +) -> tuple[list[int], list[float]]: """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) @@ -18,7 +18,7 @@ def importance_sampling( return sample_indices, correction_factors -def calibrate_llm_logprobs(true_probs: list[float]): +def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" num_quantiles = 50 quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) @@ -33,7 +33,7 @@ def learn_cascade_thresholds( recall_target: float, precision_target: float, delta: float -): +) -> tuple[tuple[float, float], int]: """Learns cascade thresholds given targets and proxy scores, oracle outputs over the sample, and correction factors for the sample.""" @@ -105,6 +105,6 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: return best_combination, oracle_calls -def calibrate_sem_sim_join(true_score: list[float]): +def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: true_score = np.clip(true_score, 0, 1) return true_score \ No newline at end of file