diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 3302a493..0e8eabdf 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -10,7 +10,8 @@ def importance_sampling( """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) - w = 0.5 * w / np.sum(w) + 0.5 * np.ones((len(proxy_scores))) / len(proxy_scores) + is_weight = lotus.settings.cascade_is_weight + w = is_weight * w / np.sum(w) + (1 - is_weight) * np.ones((len(proxy_scores))) / len(proxy_scores) indices = np.arange(len(proxy_scores)) sample_size = (int) (sample_percentage * len(proxy_scores)) sample_indices = np.random.choice(indices, sample_size, p=w) @@ -20,7 +21,7 @@ def importance_sampling( def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" - num_quantiles = 50 + num_quantiles = lotus.settings.cascade_num_calibration_quantiles quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) true_probs = np.clip(true_probs, 0, 1) diff --git a/lotus/settings.py b/lotus/settings.py index eb0e9feb..ee8acba4 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -113,3 +113,4 @@ def __repr__(self) -> str: # set defaults settings = Settings() +settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) \ No newline at end of file