Skip to content

Commit

Permalink
Implements learned filter cascade (#23)
Browse files Browse the repository at this point in the history
- Adds learned filter cascade code from research experiments
- Updates op_examples and github tests for the filter cascade operations
  • Loading branch information
pgasawa authored Oct 29, 2024
1 parent 8d61eb2 commit 9dab975
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 33 deletions.
76 changes: 62 additions & 14 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,83 @@ def test_filter_operation(setup_models):
lotus.settings.configure(lm=gpt_4o_mini)

# Test filter operation on an easy dataframe
data = {"Text": ["I am really exicted to go to class today!", "I am very sad"]}
data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
df = pd.DataFrame(data)
user_instruction = "{Text} is a positive sentiment"
filtered_df = df.sem_filter(user_instruction)

expected_df = pd.DataFrame({"Text": ["I am really exicted to go to class today!"]})
expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]})
assert filtered_df.equals(expected_df)


def test_filter_cascade(setup_models):
    """Check that a learned filter cascade keeps positives and uses the helper model.

    The large model (gpt-4o) acts as the oracle and gpt-4o-mini as the helper.
    Cascade thresholds are learned from a 50% sample of the rows, so the frame
    holds 20 clearly positive and 20 clearly negative sentences.
    """
    gpt_4o_mini, gpt_4o = setup_models
    lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)

    data = {
        "Text": [
            # Positive examples
            "I am really excited to go to class today!",
            "Today is going to be an amazing day!",
            "I absolutely love the new project I am working on.",
            "Feeling so grateful for everything I have.",
            "I can't wait to see my friends this weekend!",
            "The weather is beautiful, and I feel fantastic.",
            "Just received some great news about my promotion!",
            "I'm so happy to have such supportive colleagues.",
            "I'm thrilled to be learning something new every day.",
            "Life is really good right now, and I feel blessed.",
            "I am proud of all the progress I've made this year.",
            "Today was productive, and I feel accomplished.",
            "I’m really enjoying my workout routine lately!",
            "Got a compliment from my manager today, feeling awesome!",
            "Looking forward to spending time with family tonight.",
            "Just finished a great book and feel inspired!",
            "Had a lovely meal with friends, life is good!",
            "Everything is going as planned, couldn't be happier.",
            "Feeling super motivated and ready to take on challenges!",
            "I appreciate all the small things that bring me joy.",

            # Negative examples
            "I am very sad.",
            "Today has been really tough; I feel exhausted.",
            "I'm feeling pretty down about how things are going.",
            "I’m overwhelmed with all these challenges.",
            "It’s hard to stay positive when things keep going wrong.",
            "I feel so alone and unappreciated.",
            "My energy is low, and nothing seems to cheer me up.",
            "Feeling anxious about everything lately.",
            "I’m disappointed with the way my project turned out.",
            "Today has been one of those days where everything goes wrong.",
            "Life feels really overwhelming right now.",
            "I can't seem to find any motivation these days.",
            "I’m worried about the future and what it holds.",
            "It's been a stressful day, and I feel mentally drained.",
            "I feel like I'm falling behind everyone else.",
            "Just can't seem to catch a break recently.",
            "I’m really struggling to keep up with all my responsibilities.",
            "Had an argument with a close friend, feeling hurt.",
            "I don’t feel supported by my team at work.",
            "Life has been tough lately, and I’m feeling down.",
        ]
    }

    df = pd.DataFrame(data)
    user_instruction = "{Text} is a positive sentiment"

    filtered_df, stats = df.sem_filter(
        user_instruction=user_instruction,
        learn_cascade_threshold_sample_percentage=0.5,
        recall_target=0.9,
        precision_target=0.9,
        failure_probability=0.2,
        return_stats=True,
    )

    assert "I am really excited to go to class today!" in filtered_df["Text"].values
    # Must match the data row exactly (trailing period included); otherwise
    # the negative check passes vacuously.
    assert "I am very sad." not in filtered_df["Text"].values
    assert stats["filters_resolved_by_helper_model"] > 0, stats


def test_top_k(setup_models):
Expand Down
108 changes: 107 additions & 1 deletion examples/op_examples/filter_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,116 @@
"Optimization Methods in Engineering",
"Digital Design and Integrated Circuits",
"Computer Security",
"Data Structures and Algorithms",
"Machine Learning",
"Artificial Intelligence",
"Natural Language Processing",
"Introduction to Robotics",
"Control Systems",
"Linear Algebra and Differential Equations",
"Database Systems",
"Cloud Computing",
"Software Engineering",
"Operating Systems",
"Discrete Mathematics",
"Numerical Methods",
"Wireless Communication Systems",
"Embedded Systems",
"Advanced Computer Architecture",
"Graph Theory",
"Cryptography and Network Security",
"Big Data Analytics",
"Deep Learning",
"Organic Chemistry",
"Molecular Biology",
"Environmental Science",
"Genetics and Evolution",
"Human Physiology",
"Introduction to Anthropology",
"Cultural Studies",
"Political Theory",
"Macroeconomics",
"Microeconomics",
"Introduction to Sociology",
"Developmental Psychology",
"Cognitive Science",
"Introduction to Philosophy",
"Ethics and Moral Philosophy",
"History of Western Civilization",
"Art History: Renaissance to Modern",
"World Literature",
"Introduction to Journalism",
"Public Speaking and Communication",
"Creative Writing",
"Music Theory",
"Introduction to Theater",
"Film Studies",
"Environmental Policy and Law",
"Sustainability and Renewable Energy",
"Urban Planning and Design",
"International Relations",
"Marketing Principles",
"Organizational Behavior",
"Financial Accounting",
"Corporate Finance",
"Business Law",
"Supply Chain Management",
"Operations Research",
"Entrepreneurship and Innovation",
"Introduction to Psychology",
"Health Economics",
"Biostatistics",
"Social Work Practice",
"Public Health Policy",
"Environmental Ethics",
"History of Political Thought",
"Quantitative Research Methods",
"Comparative Politics",
"Urban Economics",
"Behavioral Economics",
"Sociology of Education",
"Social Psychology",
"Gender Studies",
"Media and Communication Studies",
"Advertising and Brand Strategy",
"Sports Management",
"Introduction to Archaeology",
"Ecology and Conservation Biology",
"Marine Biology",
"Geology and Earth Science",
"Astronomy and Astrophysics",
"Introduction to Meteorology",
"Introduction to Oceanography",
"Quantum Physics",
"Thermodynamics",
"Fluid Mechanics",
"Solid State Physics",
"Classical Mechanics",
"Introduction to Civil Engineering",
"Material Science and Engineering",
"Structural Engineering",
"Environmental Engineering",
"Energy Systems Engineering",
"Aerodynamics",
"Heat Transfer",
"Renewable Energy Systems",
"Transportation Engineering",
"Water Resources Management",
"Principles of Accounting",
"Project Management",
"International Business",
"Business Analytics",
]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df, stats = df.sem_filter(user_instruction, cascade_threshold=0.95, return_stats=True)
df, stats = df.sem_filter(
user_instruction=user_instruction,
learn_cascade_threshold_sample_percentage=0.5,
recall_target=0.9,
precision_target=0.9,
failure_probability=0.2,
return_stats=True,
)
print(df)
print(stats)
39 changes: 38 additions & 1 deletion lotus/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def handle_chat_request(
then a list of logprobs is also returned.
"""
if kwargs.get("logprobs", False):
kwargs["top_logprobs"] = 1
kwargs["top_logprobs"] = 10

kwargs = {**self.kwargs, **kwargs}
kwargs["messages"] = messages
Expand Down Expand Up @@ -127,6 +127,8 @@ def handle_completion_request(

kwargs = {**self.kwargs, **kwargs}
kwargs["prompt"] = prompt
if kwargs.get("logprobs", False):
kwargs["logprobs"] = 10
response = self.completion_request(**kwargs)

choices = response["choices"]
Expand Down Expand Up @@ -254,6 +256,41 @@ def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]],
all_confidences.append(confidences)

return all_tokens, all_confidences

def format_logprobs_for_filter_cascade(
    self, logprobs: list
) -> tuple[list[list[str]], list[list[float]], list[float]]:
    """Extract tokens, token confidences and P(True) from raw logprobs.

    Args:
        logprobs: One logprobs payload per generation. For the ``"vllm"``
            provider each entry is read through its ``"tokens"``,
            ``"token_logprobs"`` and ``"top_logprobs"`` keys; for the
            ``"openai"`` provider through its ``"content"`` list, whose
            items carry ``"token"``, ``"logprob"`` and ``"top_logprobs"``.

    Returns:
        A 3-tuple ``(all_tokens, all_confidences, all_true_probs)``:
        per-generation token lists, per-generation arrays of
        ``exp(logprob)``, and per-generation probability that the first
        generated token is ``True``.

    Raises:
        ValueError: If ``self.provider`` is neither ``"vllm"`` nor ``"openai"``
            and there is at least one entry to format.
    """

    def true_probability(top_logprobs: dict) -> float:
        # Renormalize over {True, False} when both alternatives are among
        # the first token's top logprobs; otherwise fall back to a hard
        # 1/0 depending on whether "True" was seen at all.
        if "True" in top_logprobs and "False" in top_logprobs:
            true_prob = np.exp(top_logprobs["True"])
            false_prob = np.exp(top_logprobs["False"])
            return true_prob / (true_prob + false_prob)
        return 1 if "True" in top_logprobs else 0

    all_tokens = []
    all_confidences = []
    all_true_probs = []
    for entry in logprobs:
        if self.provider == "vllm":
            tokens = entry["tokens"]
            confidences = np.exp(entry["token_logprobs"])
            # First position's alternatives, already a token -> logprob map.
            first_top_logprobs = entry["top_logprobs"][0]
        elif self.provider == "openai":
            content = entry["content"]
            tokens = [item["token"] for item in content]
            confidences = np.exp([item["logprob"] for item in content])
            first_top_logprobs = {
                alt["token"]: alt["logprob"] for alt in content[0]["top_logprobs"]
            }
        else:
            # Previously this fell through with `tokens` unbound (NameError)
            # or silently reused the previous iteration's values.
            raise ValueError(
                f"Unsupported provider for filter cascade logprobs: {self.provider!r}"
            )

        all_tokens.append(tokens)
        all_confidences.append(confidences)
        all_true_probs.append(true_probability(first_top_logprobs))

    return all_tokens, all_confidences, all_true_probs

def chat_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
"""Send chat request to OpenAI server.
Expand Down
110 changes: 110 additions & 0 deletions lotus/sem_ops/cascade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import numpy as np

import lotus


def importance_sampling(
    proxy_scores: list[float],
    sample_percentage: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Draw sample indices for learning cascade thresholds via importance sampling.

    The proposal distribution is an even mixture of a distribution
    proportional to ``sqrt(proxy_score)`` and the uniform distribution.
    Indices are drawn with replacement through ``np.random.choice``.

    Args:
        proxy_scores: Proxy (helper model) score for every item.
        sample_percentage: Fraction of the items to draw; the sample size is
            ``int(sample_percentage * len(proxy_scores))``.

    Returns:
        ``(sample_indices, correction_factors)``. ``correction_factors`` has
        one entry per *item* (not per sample) and equals the uniform
        probability divided by the proposal probability of that item.
    """
    num_items = len(proxy_scores)
    if num_items == 0:
        # Nothing to sample from; avoid dividing by a zero length below.
        return np.array([], dtype=int), np.array([], dtype=float)

    uniform = np.ones(num_items) / num_items
    sqrt_scores = np.sqrt(proxy_scores)
    total = np.sum(sqrt_scores)
    if total == 0:
        # All scores are zero: the score-proportional component is undefined
        # (0/0 -> NaN would make np.random.choice raise), so sample uniformly.
        proposal = uniform
    else:
        proposal = 0.5 * sqrt_scores / total + 0.5 * uniform

    indices = np.arange(num_items)
    sample_size = int(sample_percentage * num_items)
    sample_indices = np.random.choice(indices, sample_size, p=proposal)
    correction_factors = (1 / num_items) / proposal

    return sample_indices, correction_factors

def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]:
    """Map each probability to the fraction of 50 empirical quantile edges at or below it.

    Quantile edges are taken from ``true_probs`` itself, so the output is a
    rank-like score in ``[0, 1]`` with a resolution of 1/50.
    """
    num_quantiles = 50
    percent_grid = np.linspace(0, 100, num_quantiles + 1)
    quantile_edges = np.percentile(true_probs, percent_grid)
    # digitize is 1-based for values at or above the first edge; shift to 0-based.
    bucket_index = np.digitize(true_probs, quantile_edges) - 1
    calibrated = bucket_index / num_quantiles
    return np.clip(calibrated, 0, 1)

def learn_cascade_thresholds(
    proxy_scores: list[float],
    oracle_outputs: list[float],
    sample_correction_factors: list[float],
    recall_target: float,
    precision_target: float,
    delta: float
) -> tuple[tuple[float, float], int]:
    """Learn the positive/negative cascade thresholds from a labelled sample.

    Args:
        proxy_scores: Helper-model scores for the sampled items.
        oracle_outputs: Oracle labels for the same items (truthy = positive).
        sample_correction_factors: Per-item weights paired positionally with
            the scores and labels (presumably the importance-sampling
            correction factors -- confirm against the caller).
        recall_target: Recall the learned thresholds should reach on the sample.
        precision_target: Precision lower bound required for a positive threshold.
        delta: Failure-probability budget used in the confidence bounds.

    Returns:
        ``((tau_plus, tau_minus), oracle_calls)`` where ``oracle_calls`` counts
        the sample scores strictly between the two thresholds.

    NOTE(review): ``recall`` divides by the weighted number of positives and
    raises ZeroDivisionError when that is zero; the ``max(...)`` searches raise
    ValueError when no candidate threshold reaches the recall target.
    """

    def UB(mean, std_dev, s, delta):
        # Upper bound: mean + std_dev / sqrt(s) * sqrt(2 * ln(1 / delta)).
        return mean + (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5)

    def LB(mean, std_dev, s, delta):
        # Lower bound: mean - std_dev / sqrt(s) * sqrt(2 * ln(1 / delta)).
        return mean - (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5)

    def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> float:
        # Items at/above pos_threshold or at/below neg_threshold are decided by
        # the helper; everything strictly in between goes to the oracle.
        helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold]
        sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold]
        # Weighted count of all positives in the sample (label * weight).
        total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs)
        # Helper-accepted positives are counted unweighted; oracle-routed
        # positives are counted with their weight.
        recall = (sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)) / total_correct
        return recall

    def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> float:
        helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold]
        sent_to_oracle = [x for x in sorted_pairs if pos_threshold > x[0] > neg_threshold]
        # Oracle-routed items are predicted positive exactly when their label is positive.
        oracle_positive = sum(x[1] for x in sent_to_oracle)
        true_positives = sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + oracle_positive
        predicted_positives = sum(1 for x in helper_accepted if x[0] >= pos_threshold) + oracle_positive
        # Defined as 0 when nothing is predicted positive.
        precision = true_positives / predicted_positives if predicted_positives > 0 else 0
        return precision

    # Build (proxy score, oracle label, correction factor) triples, highest score first
    paired_data = list(zip(proxy_scores, oracle_outputs, sample_correction_factors))
    sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True)
    sample_size = len(sorted_pairs)

    best_combination = (1,0) # initial tau_+, tau_-

    # Find tau_negative based on recall: the largest sample score that, used as
    # the negative threshold, still reaches recall_target
    tau_neg_0 = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target)
    best_combination = (best_combination[0], tau_neg_0)

    # Do a statistical correction to get a new target recall
    # Z1 / Z2: weighted labels at-or-above / below the provisional tau_negative
    Z1 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] >= best_combination[1]]
    Z2 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] < best_combination[1]]

    mean_z1 = np.mean(Z1) if Z1 else 0
    std_z1 = np.std(Z1) if Z1 else 0
    mean_z2 = np.mean(Z2) if Z2 else 0
    std_z2 = np.std(Z2) if Z2 else 0

    # Each bound gets half of the failure budget (delta/2); the target is capped at 1
    corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta/2)/(UB(mean_z1, std_z1, sample_size, delta/2) + LB(mean_z2, std_z2, sample_size, delta/2))
    corrected_recall_target = min(1, corrected_recall_target)
    tau_neg_prime = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target)
    best_combination = (best_combination[0], tau_neg_prime)

    # Do a statistical correction to get a target satisfying precision
    # 1 is always a candidate, so min(candidate_thresholds) below is never empty
    candidate_thresholds = [1]
    for pair in sorted_pairs:
        possible_threshold = pair[0]
        # Unweighted labels of every item scoring at or above the candidate
        Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold]
        mean_z = np.mean(Z) if Z else 0
        std_z = np.std(Z) if Z else 0
        # delta is split evenly across all candidate thresholds
        p_l = LB(mean_z, std_z, len(Z), delta/len(sorted_pairs))
        if p_l > precision_target:
            candidate_thresholds.append(possible_threshold)

    # tau_+ is the smallest qualifying candidate, but never below tau_-
    best_combination = (max(best_combination[1], min(candidate_thresholds)), best_combination[1])
    oracle_calls = sum(1 for x in proxy_scores if best_combination[0] > x > best_combination[1])

    # Sample recall is logged with all weights replaced by 1; sample precision
    # is logged on the original triples
    no_correction_sorted_pairs = [tup[:2] + (1,) for tup in sorted_pairs]
    lotus.logger.info(f"Sample recall: {recall(best_combination[0], best_combination[1], no_correction_sorted_pairs)}")
    lotus.logger.info(f"Sample precision: {precision(best_combination[0], best_combination[1], sorted_pairs)}")

    return best_combination, oracle_calls

def calibrate_sem_sim_join(true_score: list[float]) -> list[float]:
    """Clamp every similarity score into the closed interval [0, 1]."""
    lower_bound, upper_bound = 0, 1
    return np.clip(true_score, lower_bound, upper_bound)
Loading

0 comments on commit 9dab975

Please sign in to comment.