Skip to content

Commit

Permalink
Core code
Browse files Browse the repository at this point in the history
  • Loading branch information
pgasawa committed Oct 28, 2024
1 parent d2b8cba commit 3c2b3fd
Show file tree
Hide file tree
Showing 4 changed files with 364 additions and 19 deletions.
108 changes: 107 additions & 1 deletion examples/op_examples/filter_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,116 @@
"Optimization Methods in Engineering",
"Digital Design and Integrated Circuits",
"Computer Security",
"Data Structures and Algorithms",
"Machine Learning",
"Artificial Intelligence",
"Natural Language Processing",
"Introduction to Robotics",
"Control Systems",
"Linear Algebra and Differential Equations",
"Database Systems",
"Cloud Computing",
"Software Engineering",
"Operating Systems",
"Discrete Mathematics",
"Numerical Methods",
"Wireless Communication Systems",
"Embedded Systems",
"Advanced Computer Architecture",
"Graph Theory",
"Cryptography and Network Security",
"Big Data Analytics",
"Deep Learning",
"Organic Chemistry",
"Molecular Biology",
"Environmental Science",
"Genetics and Evolution",
"Human Physiology",
"Introduction to Anthropology",
"Cultural Studies",
"Political Theory",
"Macroeconomics",
"Microeconomics",
"Introduction to Sociology",
"Developmental Psychology",
"Cognitive Science",
"Introduction to Philosophy",
"Ethics and Moral Philosophy",
"History of Western Civilization",
"Art History: Renaissance to Modern",
"World Literature",
"Introduction to Journalism",
"Public Speaking and Communication",
"Creative Writing",
"Music Theory",
"Introduction to Theater",
"Film Studies",
"Environmental Policy and Law",
"Sustainability and Renewable Energy",
"Urban Planning and Design",
"International Relations",
"Marketing Principles",
"Organizational Behavior",
"Financial Accounting",
"Corporate Finance",
"Business Law",
"Supply Chain Management",
"Operations Research",
"Entrepreneurship and Innovation",
"Introduction to Psychology",
"Health Economics",
"Biostatistics",
"Social Work Practice",
"Public Health Policy",
"Environmental Ethics",
"History of Political Thought",
"Quantitative Research Methods",
"Comparative Politics",
"Urban Economics",
"Behavioral Economics",
"Sociology of Education",
"Social Psychology",
"Gender Studies",
"Media and Communication Studies",
"Advertising and Brand Strategy",
"Sports Management",
"Introduction to Archaeology",
"Ecology and Conservation Biology",
"Marine Biology",
"Geology and Earth Science",
"Astronomy and Astrophysics",
"Introduction to Meteorology",
"Introduction to Oceanography",
"Quantum Physics",
"Thermodynamics",
"Fluid Mechanics",
"Solid State Physics",
"Classical Mechanics",
"Introduction to Civil Engineering",
"Material Science and Engineering",
"Structural Engineering",
"Environmental Engineering",
"Energy Systems Engineering",
"Aerodynamics",
"Heat Transfer",
"Renewable Energy Systems",
"Transportation Engineering",
"Water Resources Management",
"Principles of Accounting",
"Project Management",
"International Business",
"Business Analytics",
]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df, stats = df.sem_filter(user_instruction, cascade_threshold=0.95, return_stats=True)
df, stats = df.sem_filter(
user_instruction=user_instruction,
learn_cascade_threshold_sample_percentage=0.5,
recall_target=0.9,
precision_target=0.9,
failure_probability=0.2,
return_stats=True,
)
print(df)
print(stats)
39 changes: 38 additions & 1 deletion lotus/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def handle_chat_request(
then a list of logprobs is also returned.
"""
if kwargs.get("logprobs", False):
kwargs["top_logprobs"] = 1
kwargs["top_logprobs"] = 10

kwargs = {**self.kwargs, **kwargs}
kwargs["messages"] = messages
Expand Down Expand Up @@ -127,6 +127,8 @@ def handle_completion_request(

kwargs = {**self.kwargs, **kwargs}
kwargs["prompt"] = prompt
if kwargs.get("logprobs", False):
kwargs["logprobs"] = 10
response = self.completion_request(**kwargs)

choices = response["choices"]
Expand Down Expand Up @@ -254,6 +256,41 @@ def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]],
all_confidences.append(confidences)

return all_tokens, all_confidences

def format_logprobs_for_filter_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]:
all_tokens = []
all_confidences = []
all_true_probs = []
for idx in range(len(logprobs)):
if self.provider == "vllm":
tokens = logprobs[idx]["tokens"]
confidences = np.exp(logprobs[idx]["token_logprobs"])
top_logprobs = logprobs[idx]["top_logprobs"][0]
if 'True' in top_logprobs and 'False' in top_logprobs:
true_prob = np.exp(top_logprobs['True'])
false_prob = np.exp(top_logprobs['False'])
all_true_probs.append(true_prob / (true_prob + false_prob))
else:
all_true_probs.append(1 if 'True' in top_logprobs else 0)

elif self.provider == "openai":
content = logprobs[idx]["content"]
tokens = [content[t_idx]["token"] for t_idx in range(len(content))]
confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))])
top_logprobs = {x["token"]:x["logprob"] for x in content[0]["top_logprobs"]}

true_prob, false_prob = 0, 0
if top_logprobs and 'True' in top_logprobs and 'False' in top_logprobs:
true_prob = np.exp(top_logprobs['True'])
false_prob = np.exp(top_logprobs['False'])
all_true_probs.append(true_prob / (true_prob + false_prob))
else:
all_true_probs.append(1 if 'True' in top_logprobs else 0)

all_tokens.append(tokens)
all_confidences.append(confidences)

return all_tokens, all_confidences, all_true_probs

def chat_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
"""Send chat request to OpenAI server.
Expand Down
110 changes: 110 additions & 0 deletions lotus/sem_ops/cascade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import numpy as np

import lotus


def importance_sampling(
proxy_scores: list[float],
sample_percentage: float,
):
"""Uses importance sampling and returns the list of indices from which to learn cascade thresholds."""

w = np.sqrt(proxy_scores)
w = 0.5 * w / np.sum(w) + 0.5 * np.ones((len(proxy_scores))) / len(proxy_scores)
indices = np.arange(len(proxy_scores))
sample_size = (int) (sample_percentage * len(proxy_scores))
sample_indices = np.random.choice(indices, sample_size, p=w)
correction_factors = (1/len(proxy_scores)) / w

return sample_indices, correction_factors

def calibrate_llm_logprobs(true_probs: list[float]):
"""Transforms true probabilities to calibrate LLM proxies."""
num_quantiles = 50
quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1))
true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles)
true_probs = np.clip(true_probs, 0, 1)
return true_probs

def learn_cascade_thresholds(
proxy_scores: list[float],
oracle_outputs: list[float],
sample_correction_factors: list[float],
recall_target: float,
precision_target: float,
delta: float
):
"""Learns cascade thresholds given targets and proxy scores,
oracle outputs over the sample, and correction factors for the
sample."""

def UB(mean, std_dev, s, delta):
return mean + (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5)

def LB(mean, std_dev, s, delta):
return mean - (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5)

def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool:
helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold]
sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold]
total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs)
recall = (sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)) / total_correct
return recall

def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool:
helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold]
sent_to_oracle = [x for x in sorted_pairs if pos_threshold > x[0] > neg_threshold]
oracle_positive = sum(x[1] for x in sent_to_oracle)
true_positives = sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + oracle_positive
predicted_positives = sum(1 for x in helper_accepted if x[0] >= pos_threshold) + oracle_positive
precision = true_positives / predicted_positives if predicted_positives > 0 else 0
return precision

# Pair helper model probabilities with helper correctness and oracle answer
paired_data = list(zip(proxy_scores, oracle_outputs, sample_correction_factors))
sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True)
sample_size = len(sorted_pairs)

best_combination = (1,0) # initial tau_+, tau_-

# Find tau_negative based on recall
tau_neg_0 = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target)
best_combination = (best_combination[0], tau_neg_0)

# Do a statistical correction to get a new target recall
Z1 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] >= best_combination[1]]
Z2 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] < best_combination[1]]

mean_z1 = np.mean(Z1) if Z1 else 0
std_z1 = np.std(Z1) if Z1 else 0
mean_z2 = np.mean(Z2) if Z2 else 0
std_z2 = np.std(Z2) if Z2 else 0

corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta/2)/(UB(mean_z1, std_z1, sample_size, delta/2) + LB(mean_z2, std_z2, sample_size, delta/2))
corrected_recall_target = min(1, corrected_recall_target)
tau_neg_prime = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target)
best_combination = (best_combination[0], tau_neg_prime)

# Do a statistical correction to get a target satisfying precision
candidate_thresholds = [1]
for pair in sorted_pairs:
possible_threshold = pair[0]
Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold]
mean_z = np.mean(Z) if Z else 0
std_z = np.std(Z) if Z else 0
p_l = LB(mean_z, std_z, len(Z), delta/len(sorted_pairs))
if p_l > precision_target:
candidate_thresholds.append(possible_threshold)

best_combination = (max(best_combination[1], min(candidate_thresholds)), best_combination[1])
oracle_calls = sum(1 for x in proxy_scores if best_combination[0] > x > best_combination[1])

no_correction_sorted_pairs = [tup[:2] + (1,) for tup in sorted_pairs]
lotus.logger.info(f"Sample recall: {recall(best_combination[0], best_combination[1], no_correction_sorted_pairs)}")
lotus.logger.info(f"Sample precision: {precision(best_combination[0], best_combination[1], sorted_pairs)}")

return best_combination, oracle_calls

def calibrate_sem_sim_join(true_score: list[float]):
true_score = np.clip(true_score, 0, 1)
return true_score
Loading

0 comments on commit 3c2b3fd

Please sign in to comment.