From 4b25ceccf78e0cf4d01340e9f26ecc6532319d2b Mon Sep 17 00:00:00 2001 From: melissa-pan Date: Thu, 14 Nov 2024 23:01:36 -0800 Subject: [PATCH] fix ruff lint + mypy --- examples/op_examples/join_cascade.py | 1 - lotus/sem_ops/sem_join.py | 14 +++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/op_examples/join_cascade.py b/examples/op_examples/join_cascade.py index 78674669..06535c0e 100644 --- a/examples/op_examples/join_cascade.py +++ b/examples/op_examples/join_cascade.py @@ -3,7 +3,6 @@ import lotus from lotus.models import LM, SentenceTransformersRM - lm = LM(model="gpt-4o-mini") rm = SentenceTransformersRM(model="intfloat/e5-base-v2") diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index cd3f034d..dde7f733 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -3,11 +3,11 @@ import pandas as pd import lotus -from lotus.types import SemanticJoinOutput from lotus.templates import task_instructions +from lotus.types import SemanticJoinOutput +from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds from .sem_filter import sem_filter -from .cascade_utils import importance_sampling, learn_cascade_thresholds, calibrate_sem_sim_join def sem_join( @@ -440,7 +440,7 @@ def learn_join_cascade_threshold( default: bool = True, strategy: str | None = None, sampling_range: tuple[int, int] | None = None, -) -> tuple[float, float, int]: +) -> tuple[float, float, float]: """ Extract a small sample of the data and find the optimal threshold pair that satisfies the recall and precision target. @@ -528,8 +528,8 @@ def __call__( default: bool = True, recall_target: float | None = None, precision_target: float | None = None, - sampling_percentage: float | None = 0.1, - failure_probability: float | None = 0.2, + sampling_percentage: float = 0.1, + failure_probability: float = 0.2, map_instruction: str | None = None, map_examples: pd.DataFrame | None = None, sampling_range: tuple[int, int] | None = None, @@ -633,8 +633,8 @@ def __call__( join_instruction, recall_target, precision_target, - sampling_percentage, - failure_probability, + sampling_percentage=sampling_percentage, + failure_probability=failure_probability, examples_df_txt=examples_df_txt, examples_answers=examples_answers, map_instruction=map_instruction,