def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]:
    """Calibrate LLM proxy probabilities by mapping them onto quantile bins.

    Each probability is replaced by the fraction ``bin_index / num_quantiles``,
    where the bin edges are the empirical percentiles of ``true_probs`` itself
    and ``num_quantiles`` is read from
    ``lotus.settings.cascade_num_calibration_quantiles``.

    Args:
        true_probs: Probabilities to calibrate.

    Returns:
        A list with one calibrated value per input, each within ``[0, 1]``.
        An empty input yields an empty list.
    """
    # An empty sequence has no percentiles to compute; np.percentile either
    # raises or produces NaN edges for it, so return early. ``len`` is used
    # instead of truthiness so array-like inputs are handled as well.
    if len(true_probs) == 0:
        return []

    num_quantiles = lotus.settings.cascade_num_calibration_quantiles
    # num_quantiles + 1 evenly spaced percentiles give num_quantiles bins.
    quantile_edges = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1))
    # np.digitize yields 1-based bin indices; shift to 0-based and normalise
    # by the number of bins to obtain a value on the unit interval.
    bin_fractions = (np.digitize(true_probs, quantile_edges) - 1) / num_quantiles
    # Clip as a safety net so the result never leaves [0, 1].
    return list(np.clip(bin_fractions, 0, 1))