Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sem_join + sem_filter pbars #46

Merged
merged 11 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import SemJoinCascadeArgs
from lotus.types import CascadeArgs

################################################################################
# Setup
Expand Down Expand Up @@ -270,10 +270,12 @@ def test_filter_cascade(setup_models):
# All filters resolved by the helper model
filtered_df, stats = df.sem_filter(
user_instruction=user_instruction,
learn_cascade_threshold_sample_percentage=0.5,
recall_target=0.9,
precision_target=0.9,
failure_probability=0.2,
cascade_args=CascadeArgs(
learn_cascade_threshold_sample_percentage=0.5,
recall_target=0.9,
precision_target=0.9,
failure_probability=0.2,
),
return_stats=True,
)

Expand All @@ -286,10 +288,12 @@ def test_filter_cascade(setup_models):
def test_join_cascade(setup_models):
models = setup_models
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
lotus.settings.configure(lm=models["gpt-4o-mini"],
rm=rm,
min_join_cascade_size=10, # for smaller testings
cascade_IS_random_seed=42)
lotus.settings.configure(
lm=models["gpt-4o-mini"],
rm=rm,
min_join_cascade_size=10, # for smaller testings
cascade_IS_random_seed=42,
)

data1 = {
"School": [
Expand All @@ -308,37 +312,41 @@ def test_join_cascade(setup_models):
"Yale University",
"Cornell University",
"University of Pennsylvania",
]}
]
}
data2 = {"School Type": ["Public School", "Private School"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2)
join_instruction = "{School} is a {School Type}"
expected_pairs = [("University of California, Berkeley", "Public School"), ("Stanford University", "Private School")]
expected_pairs = [
("University of California, Berkeley", "Public School"),
("Stanford University", "Private School"),
]

# Cascade join
joined_df, stats = df1.sem_join(
df2, join_instruction,
cascade_args=SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7),
return_stats=True)
df2, join_instruction, cascade_args=CascadeArgs(recall_target=0.7, precision_target=0.7), return_stats=True
)

for pair in expected_pairs:
school, school_type = pair
exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any()
exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
assert exists, f"Expected pair {pair} does not exist in the dataframe!"
assert stats["join_resolved_by_helper_model"] > 0, stats

# All joins resolved by the large model
joined_df, stats = df1.sem_join(
df2, join_instruction,
cascade_args=SemJoinCascadeArgs(recall_target=1.0, precision_target=1.0),
return_stats=True)
df2, join_instruction, cascade_args=CascadeArgs(recall_target=1.0, precision_target=1.0), return_stats=True
)

for pair in expected_pairs:
school, school_type = pair
exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any()
exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
assert exists, f"Expected pair {pair} does not exist in the dataframe!"
assert stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"], stats # helper negatives can still meet the precision target
assert (
stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"]
), stats # helper negatives can still meet the precision target
assert stats["join_helper_positive"] == 0, stats


Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
pip install ruff==0.7.2

- name: Run ruff
run: ruff check .
run: ruff check lotus/

mypy:
name: Type Check
Expand Down
5 changes: 2 additions & 3 deletions examples/op_examples/agg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import pandas as pd

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")

lotus.settings.configure(lm=lm, rm=rm)
lotus.settings.configure(lm=lm)
data = {
"Course Name": [
"Probability and Random Processes",
Expand Down
14 changes: 6 additions & 8 deletions examples/op_examples/filter_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import lotus
from lotus.models import LM
from lotus.types import CascadeArgs


gpt_35_turbo = LM("gpt-3.5-turbo")
gpt_4o = LM("gpt-4o")
Expand Down Expand Up @@ -116,13 +118,9 @@
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df, stats = df.sem_filter(
user_instruction=user_instruction,
learn_cascade_threshold_sample_percentage=0.5,
recall_target=0.9,
precision_target=0.9,
failure_probability=0.2,
return_stats=True,
)

cascade_args = CascadeArgs(recall_target=0.9, precision_target=0.9, sampling_percentage=0.5, failure_probability=0.2)

df, stats = df.sem_filter(user_instruction=user_instruction, cascade_args=cascade_args, return_stats=True)
print(df)
print(stats)
4 changes: 2 additions & 2 deletions examples/op_examples/join_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import SemJoinCascadeArgs
from lotus.types import CascadeArgs

lm = LM(model="gpt-4o-mini")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
Expand Down Expand Up @@ -124,7 +124,7 @@
df2 = pd.DataFrame(data2)
join_instruction = "By taking {Course Name:left} I will learn {Skill:right}"

cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7)
cascade_args = CascadeArgs(recall_target=0.7, precision_target=0.7)
res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True)


Expand Down
2 changes: 1 addition & 1 deletion lotus/models/cross_encoder_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ def __init__(
self.model = CrossEncoder(model, device=device) # type: ignore # CrossEncoder has wrong type stubs

def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput:
results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size)
results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size, show_progress_bar=False)
indices = [int(result["corpus_id"]) for result in results]
return RerankerOutput(indices=indices)
26 changes: 22 additions & 4 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def __init__(
self.cache = Cache(max_cache_size)

def __call__(
self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
self,
messages: list[list[dict[str, str]]],
show_progress_bar: bool = True,
progress_bar_desc: str = "Processing uncached messages",
**kwargs: dict[str, Any],
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

Expand All @@ -59,7 +63,9 @@ def __call__(
self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs)
uncached_responses = self._process_uncached_messages(
uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
)

# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
Expand All @@ -74,12 +80,24 @@ def __call__(

return LMOutput(outputs=outputs, logprobs=logprobs)

def _process_uncached_messages(self, uncached_data, all_kwargs):
def _process_uncached_messages(self, uncached_data, all_kwargs, show_progress_bar, progress_bar_desc):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
total_calls = len(uncached_data)

pbar = tqdm(
total=total_calls,
desc=progress_bar_desc,
disable=not show_progress_bar,
bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
)
for i in range(0, total_calls, self.max_batch_size):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))

pbar.update(len(batch))
pbar.close()

return uncached_responses

def _cache_response(self, response, hash):
Expand Down
2 changes: 1 addition & 1 deletion lotus/models/sentence_transformers_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
batch = docs[i : i + self.max_batch_size]
_batch = convert_to_base_data(batch)
torch_embeddings = self.transformer.encode(
_batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings
_batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings, show_progress_bar=False
)
assert isinstance(torch_embeddings, torch.Tensor)
cpu_embeddings = torch_embeddings.cpu().numpy()
Expand Down
18 changes: 12 additions & 6 deletions lotus/sem_ops/cascade_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def importance_sampling(
sample_size = int(sample_percentage * len(proxy_scores))
sample_indices = np.random.choice(indices, sample_size, p=sample_w)

correction_factors = (1/len(proxy_scores)) / w
correction_factors = (1 / len(proxy_scores)) / w

return sample_indices, correction_factors

Expand Down Expand Up @@ -65,8 +65,14 @@ def recall(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[
sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold]
total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs)
recall = (
sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)
) / total_correct if total_correct > 0 else 0.0
(
sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1])
+ sum(x[1] * x[2] for x in sent_to_oracle)
)
/ total_correct
if total_correct > 0
else 0.0
)
return recall

def precision(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float:
Expand All @@ -80,8 +86,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tup

def calculate_tau_neg(sorted_pairs: list[tuple[float, bool, float]], tau_pos: float, recall_target: float) -> float:
return max(
(x[0] for x in sorted_pairs[::-1] if recall(tau_pos, x[0], sorted_pairs) >= recall_target),
default=0
(x[0] for x in sorted_pairs[::-1] if recall(tau_pos, x[0], sorted_pairs) >= recall_target), default=0
)

# Pair helper model probabilities with helper correctness and oracle answer
Expand Down Expand Up @@ -135,6 +140,7 @@ def calculate_tau_neg(sorted_pairs: list[tuple[float, bool, float]], tau_pos: fl

return best_combination, oracle_calls


def calibrate_sem_sim_join(true_score: list[float]) -> list[float]:
true_score = list(np.clip(true_score, 0, 1))
return true_score
return true_score
7 changes: 5 additions & 2 deletions lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def sem_agg(
user_instruction: str,
partition_ids: list[int],
safe_mode: bool = False,
progress_bar_desc: str = "Aggregating",
) -> SemanticAggOutput:
"""
Aggregates multiple documents into a single answer using a model.
Expand Down Expand Up @@ -115,7 +116,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
batch.append([{"role": "user", "content": prompt}])
new_partition_ids.append(cur_partition_id)

lm_output: LMOutput = model(batch)
lm_output: LMOutput = model(batch, progress_bar_desc=progress_bar_desc)

summaries = lm_output.outputs
partition_ids = new_partition_ids
Expand Down Expand Up @@ -149,6 +150,7 @@ def __call__(
suffix: str = "_output",
group_by: list[str] | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Aggregating",
) -> pd.DataFrame:
"""
Applies semantic aggregation over a dataframe.
Expand Down Expand Up @@ -178,7 +180,7 @@ def __call__(
grouped = self._obj.groupby(group_by)
new_df = pd.DataFrame()
for name, group in grouped:
res = group.sem_agg(user_instruction, all_cols, suffix, None)
res = group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)
new_df = pd.concat([new_df, res])
return new_df

Expand All @@ -200,6 +202,7 @@ def __call__(
formatted_usr_instr,
partition_ids,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
)

# package answer in a dataframe
Expand Down
5 changes: 4 additions & 1 deletion lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def sem_extract(
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> SemanticExtractOutput:
"""
Extracts attributes and values from a list of documents using a model.
Expand Down Expand Up @@ -48,7 +49,7 @@ def sem_extract(
show_safe_mode(estimated_cost, estimated_LM_calls)

# call model
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc)

# post process results
postprocess_output = postprocessor(lm_output.outputs)
Expand Down Expand Up @@ -79,6 +80,7 @@ def __call__(
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> pd.DataFrame:
"""
Extracts the attributes and values of a dataframe.
Expand Down Expand Up @@ -108,6 +110,7 @@ def __call__(
extract_quotes=extract_quotes,
postprocessor=postprocessor,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
)

new_df = self._obj.copy()
Expand Down
Loading
Loading