Skip to content

Commit

Permalink
sem_join +sem_join_cascade safe mode
Browse files Browse the repository at this point in the history
  • Loading branch information
StanChan03 committed Dec 5, 2024
1 parent c693adf commit 22e2b36
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 19 deletions.
Binary file added examples/op_examples/Skill:right_index/index
Binary file not shown.
Binary file added examples/op_examples/Skill:right_index/vecs
Binary file not shown.
121 changes: 102 additions & 19 deletions examples/op_examples/join_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from lotus.types import SemJoinCascadeArgs

lm = LM(model="gpt-4o-mini")
helper_lm = LM(model="gpt-3.5-turbo")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")

lotus.settings.configure(lm=lm, rm=rm)
lotus.settings.configure(lm=lm, rm=rm, helper_lm=helper_lm)
data = {
"Course Name": [
"Digital Design and Integrated Circuits",
Expand All @@ -18,22 +19,104 @@
}

skills = [
"Math", "Computer Science", "Management", "Creative Writing", "Data Analysis", "Machine Learning",
"Project Management", "Problem Solving", "Singing", "Critical Thinking", "Public Speaking", "Teamwork",
"Adaptability", "Programming", "Leadership", "Time Management", "Negotiation", "Decision Making", "Networking",
"Painting", "Customer Service", "Marketing", "Graphic Design", "Nursery", "SEO", "Content Creation",
"Video Editing", "Sales", "Financial Analysis", "Accounting", "Event Planning", "Foreign Languages",
"Software Development", "Cybersecurity", "Social Media Management", "Photography", "Writing & Editing",
"Technical Support", "Database Management", "Web Development", "Business Strategy", "Operations Management",
"UI/UX Design", "Reinforcement Learning", "Data Visualization", "Product Management", "Cloud Computing",
"Agile Methodology", "Blockchain", "IT Support", "Legal Research", "Supply Chain Management", "Copywriting",
"Human Resources", "Quality Assurance", "Medical Research", "Healthcare Management", "Sports Coaching",
"Editing & Proofreading", "Legal Writing", "Human Anatomy", "Chemistry", "Physics", "Biology", "Psychology",
"Sociology", "Anthropology", "Political Science", "Public Relations", "Fashion Design", "Interior Design",
"Automotive Repair", "Plumbing", "Carpentry", "Electrical Work", "Welding", "Electronics", "Hardware Engineering",
"Circuit Design", "Robotics", "Environmental Science", "Marine Biology", "Urban Planning", "Geography",
"Agricultural Science", "Animal Care", "Veterinary Science", "Zoology", "Ecology", "Botany", "Landscape Design",
"Baking & Pastry", "Culinary Arts", "Bartending", "Nutrition", "Dietary Planning", "Physical Training", "Yoga",
"Math",
"Computer Science",
"Management",
"Creative Writing",
"Data Analysis",
"Machine Learning",
"Project Management",
"Problem Solving",
"Singing",
"Critical Thinking",
"Public Speaking",
"Teamwork",
"Adaptability",
"Programming",
"Leadership",
"Time Management",
"Negotiation",
"Decision Making",
"Networking",
"Painting",
"Customer Service",
"Marketing",
"Graphic Design",
"Nursery",
"SEO",
"Content Creation",
"Video Editing",
"Sales",
"Financial Analysis",
"Accounting",
"Event Planning",
"Foreign Languages",
"Software Development",
"Cybersecurity",
"Social Media Management",
"Photography",
"Writing & Editing",
"Technical Support",
"Database Management",
"Web Development",
"Business Strategy",
"Operations Management",
"UI/UX Design",
"Reinforcement Learning",
"Data Visualization",
"Product Management",
"Cloud Computing",
"Agile Methodology",
"Blockchain",
"IT Support",
"Legal Research",
"Supply Chain Management",
"Copywriting",
"Human Resources",
"Quality Assurance",
"Medical Research",
"Healthcare Management",
"Sports Coaching",
"Editing & Proofreading",
"Legal Writing",
"Human Anatomy",
"Chemistry",
"Physics",
"Biology",
"Psychology",
"Sociology",
"Anthropology",
"Political Science",
"Public Relations",
"Fashion Design",
"Interior Design",
"Automotive Repair",
"Plumbing",
"Carpentry",
"Electrical Work",
"Welding",
"Electronics",
"Hardware Engineering",
"Circuit Design",
"Robotics",
"Environmental Science",
"Marine Biology",
"Urban Planning",
"Geography",
"Agricultural Science",
"Animal Care",
"Veterinary Science",
"Zoology",
"Ecology",
"Botany",
"Landscape Design",
"Baking & Pastry",
"Culinary Arts",
"Bartending",
"Nutrition",
"Dietary Planning",
"Physical Training",
"Yoga",
]
data2 = pd.DataFrame({"Skill": skills})

Expand All @@ -42,7 +125,7 @@
df2 = pd.DataFrame(data2)
join_instruction = "By taking {Course Name:left} I will learn {Skill:right}"

cascade_args = SemJoinCascadeArgs(recall_target = 0.7, precision_target = 0.7)
cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7)
res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True)


Expand All @@ -51,4 +134,4 @@
print(f" Helper resolved {stats['join_resolved_by_helper_model']} LM calls")
print(f"Join cascade used {stats['total_LM_calls']} LM calls in total")
print(f"Naive join would require {df1.shape[0]*df2.shape[0]} LM calls")
print(res)
print(res)
55 changes: 55 additions & 0 deletions lotus/sem_ops/sem_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import lotus
from lotus.templates import task_instructions
from lotus.types import SemanticJoinOutput, SemJoinCascadeArgs
from lotus.utils import show_safe_mode

from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds
from .sem_filter import sem_filter
Expand All @@ -24,6 +25,7 @@ def sem_join(
cot_reasoning: list[str] | None = None,
default: bool = True,
strategy: str | None = None,
safe_mode: bool = False,
) -> SemanticJoinOutput:
"""
Joins two series using a model.
Expand Down Expand Up @@ -53,6 +55,20 @@ def sem_join(

left_multimodal_data = task_instructions.df2multimodal_info(l1.to_frame(col1_label), [col1_label])
right_multimodal_data = task_instructions.df2multimodal_info(l2.to_frame(col2_label), [col2_label])

sample_docs = task_instructions.merge_multimodal_info([left_multimodal_data[0]], right_multimodal_data)
estimated_tokens_per_call = model.count_tokens(
lotus.templates.task_instructions.filter_formatter(
sample_docs[0], user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy
)
)
estimated_total_calls = len(l1) * len(l2)
estimated_total_cost = estimated_tokens_per_call * estimated_total_calls
if safe_mode:
print("Sem_Join:")
show_safe_mode(estimated_total_cost, estimated_total_calls)
print("\n")

# for i1 in enumerate(l1):
for id1, i1 in zip(ids1, left_multimodal_data):
# perform llm filter
Expand Down Expand Up @@ -113,6 +129,7 @@ def sem_join_cascade(
cot_reasoning: list[str] | None = None,
default: bool = True,
strategy: str | None = None,
safe_mode: bool = False,
) -> SemanticJoinOutput:
"""
Joins two series using a cascade helper model and a large model.
Expand Down Expand Up @@ -182,6 +199,41 @@ def sem_join_cascade(
num_helper = len(helper_high_conf)
num_large = len(helper_low_conf)

if safe_mode:
sample_docs1 = task_instructions.df2multimodal_info(l1.to_frame(col1_label), [col1_label])
sample_docs2 = task_instructions.df2multimodal_info(l2.to_frame(col2_label), [col2_label])
sample_helper_doc = task_instructions.merge_multimodal_info([sample_docs1[0]], sample_docs2)

estimated_tokens_helper = lotus.settings.helper_lm.count_tokens(
lotus.templates.task_instructions.filter_formatter(
sample_helper_doc[0],
user_instruction,
examples_multimodal_data,
examples_answers,
cot_reasoning,
strategy,
)
)
estimated_tokens_large = lotus.settings.lm.count_tokens(
lotus.templates.task_instructions.filter_formatter(
sample_helper_doc[0],
user_instruction,
examples_multimodal_data,
examples_answers,
cot_reasoning,
strategy,
)
)

total_helper_tokens = estimated_tokens_helper * num_helper
total_large_tokens = estimated_tokens_large * num_large
total_tokens = total_helper_tokens + total_large_tokens + join_optimization_cost
total_lm_calls = num_helper + num_large

print("Sem_join_cascade:")
show_safe_mode(total_tokens, total_lm_calls)
print("\n")

# Accept helper results with high confidence
join_results = [(row["_left_id"], row["_right_id"], None) for _, row in helper_high_conf.iterrows()]

Expand Down Expand Up @@ -538,6 +590,7 @@ def __call__(
default: bool = True,
cascade_args: SemJoinCascadeArgs | None = None,
return_stats: bool = False,
safe_mode: bool = False,
) -> pd.DataFrame:
"""
Applies semantic join over a dataframe.
Expand Down Expand Up @@ -647,6 +700,7 @@ def __call__(
cot_reasoning=cot_reasoning,
default=default,
strategy=strategy,
safe_mode=safe_mode,
)
else:
output = sem_join(
Expand All @@ -663,6 +717,7 @@ def __call__(
cot_reasoning=cot_reasoning,
default=default,
strategy=strategy,
safe_mode=safe_mode,
)
join_results = output.join_results
all_raw_outputs = output.all_raw_outputs
Expand Down

0 comments on commit 22e2b36

Please sign in to comment.