From 19b35968f30786eb100463ed3273ac76ca94d3c8 Mon Sep 17 00:00:00 2001 From: Parth Asawa <37985050+pgasawa@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:35:00 -0700 Subject: [PATCH] Add default cascade settings (#25) Moves the number of calibration quantiles (cascade_num_calibration_quantiles, default 50) and the importance-sampling weight (cascade_is_weight, default 0.5) out of hard-coded constants and into the default settings. --- lotus/sem_ops/cascade_utils.py | 5 +++-- lotus/settings.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 3302a493..0e8eabdf 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -10,7 +10,8 @@ def importance_sampling( """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) - w = 0.5 * w / np.sum(w) + 0.5 * np.ones((len(proxy_scores))) / len(proxy_scores) + is_weight = lotus.settings.cascade_is_weight + w = is_weight * w / np.sum(w) + (1 - is_weight) * np.ones((len(proxy_scores))) / len(proxy_scores) indices = np.arange(len(proxy_scores)) sample_size = (int) (sample_percentage * len(proxy_scores)) sample_indices = np.random.choice(indices, sample_size, p=w) @@ -20,7 +21,7 @@ def importance_sampling( def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" - num_quantiles = 50 + num_quantiles = lotus.settings.cascade_num_calibration_quantiles quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) true_probs = np.clip(true_probs, 0, 1) diff --git a/lotus/settings.py b/lotus/settings.py index eb0e9feb..ee8acba4 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -113,3 +113,4 @@ def __repr__(self) -> str: # set defaults settings = Settings() +settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) \ No newline at end of file