From 1ec4d8a4a800011d21ef7d0e6483e235a4e3af87 Mon Sep 17 00:00:00 2001 From: Dhruv Iyer Date: Thu, 26 Dec 2024 15:28:30 -0700 Subject: [PATCH] Parallelizing sem_agg and sem_top_k (#66) Addressing #61 --- examples/op_examples/agg_with_grouping.py | 244 ++++++++++++++++++++ examples/op_examples/top_k_with_grouping.py | 60 +++++ lotus/sem_ops/sem_agg.py | 22 +- lotus/sem_ops/sem_topk.py | 47 ++-- lotus/settings.py | 8 + tests/test_settings.py | 25 ++ 6 files changed, 374 insertions(+), 32 deletions(-) create mode 100644 examples/op_examples/agg_with_grouping.py create mode 100644 examples/op_examples/top_k_with_grouping.py create mode 100644 tests/test_settings.py diff --git a/examples/op_examples/agg_with_grouping.py b/examples/op_examples/agg_with_grouping.py new file mode 100644 index 0000000..05f925b --- /dev/null +++ b/examples/op_examples/agg_with_grouping.py @@ -0,0 +1,244 @@ +import time + +import pandas as pd + +import lotus +from lotus.models import LM + +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) + +data = { + "Course Name": [ + "Probability and Random Processes", + "Optimization Methods in Engineering", + "Digital Design and Integrated Circuits", + "Computer Security", + "Cooking", + "Food Sciences", + "Machine Learning", + "Data Structures and Algorithms", + "Quantum Mechanics", + "Organic Chemistry", + "Artificial Intelligence", + "Robotics", + "Thermodynamics", + "Fluid Mechanics", + "Molecular Biology", + "Genetics", + "Astrophysics", + "Neuroscience", + "Microeconomics", + "Macroeconomics", + "Linear Algebra", + "Calculus", + "Statistics", + "Differential Equations", + "Discrete Mathematics", + "Number Theory", + "Graph Theory", + "Topology", + "Complex Analysis", + "Real Analysis", + "Abstract Algebra", + "Numerical Methods", + "Cryptography", + "Network Security", + "Operating Systems", + "Databases", + "Computer Networks", + "Software Engineering", + "Compilers", + "Computer Architecture", + "Parallel Computing", + "Distributed Systems", + "Cloud Computing", + "Big Data Analytics", + "Natural Language Processing", + "Computer Vision", + "Reinforcement Learning", + "Deep Learning", + "Bioinformatics", + "Computational Biology", + "Systems Biology", + "Biochemistry", + "Physical Chemistry", + "Inorganic Chemistry", + "Analytical Chemistry", + "Environmental Chemistry", + "Materials Science", + "Nanotechnology", + "Optics", + "Electromagnetism", + "Nuclear Physics", + "Particle Physics", + "Cosmology", + "Planetary Science", + "Geophysics", + "Atmospheric Science", + "Oceanography", + "Ecology", + "Evolutionary Biology", + "Botany", + "Zoology", + "Microbiology", + "Immunology", + "Virology", + "Pharmacology", + "Physiology", + "Anatomy", + "Neurobiology", + "Cognitive Science", + "Psychology", + "Sociology", + "Anthropology", + "Archaeology", + "Linguistics", + "Philosophy", + "Ethics", + "Logic", + "Political Science", + "International Relations", + "Public Policy", + "Economics", + "Finance", + "Accounting", + "Marketing", + "Management", + "Entrepreneurship", + "Law", + "Criminal Justice", + "Human Rights", + "Environmental Studies", + "Sustainability", + "Urban Planning", + "Architecture", + "Civil Engineering", + "Mechanical Engineering", + "Electrical Engineering", + "Chemical Engineering", + "Aerospace Engineering", + "Biomedical Engineering", + "Environmental Engineering", + ], + "Grade Level": [ + "High School", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "Undergraduate", + "High School", + "Undergraduate", + "High School", + "Undergraduate", + "High School", + "Graduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "High School", + "Undergraduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "High School", + "High School", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Graduate", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + ], +} + +df = pd.DataFrame(data) +start_time = time.time() +df = df.sem_agg("Summarize all {Course Name}", group_by=["Grade Level"]) +end_time = time.time() +print(df._output[0]) +print(f"Total execution time: {end_time - start_time:.2f} seconds") diff --git a/examples/op_examples/top_k_with_grouping.py b/examples/op_examples/top_k_with_grouping.py new file mode 100644 index 0000000..ab1cf06 --- /dev/null +++ b/examples/op_examples/top_k_with_grouping.py @@ -0,0 +1,60 @@ +import time + +import pandas as pd + +import lotus +from lotus.models import LM + +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) + +data = { + "Department": ["Math", "Physics", "Computer Science", "Biology"] * 7, + "Course Name": [ + "Calculus", + "Quantum Mechanics", + "Data Structures", + "Genetics", + "Linear Algebra", + "Thermodynamics", + "Algorithms", + "Ecology", + "Statistics", + "Optics", + "Machine Learning", + "Molecular Biology", + "Number Theory", + "Relativity", + "Computer Networks", + "Evolutionary Biology", + "Differential Equations", + "Particle Physics", + "Operating Systems", + "Biochemistry", + "Complex Analysis", + "Fluid Dynamics", + "Artificial Intelligence", + "Microbiology", + "Topology", + "Astrophysics", + "Cybersecurity", + "Immunology", + ], +} + +df = pd.DataFrame(data) + +for method in ["quick", "heap", "naive"]: + start_time = time.time() + sorted_df, stats = df.sem_topk( + "Which {Course Name} is the most challenging?", + K=2, + method=method, + return_stats=True, + group_by=["Department"], + ) + end_time = time.time() + print(sorted_df) + print(stats) + print(f"Total execution time: {end_time - start_time:.2f} seconds") diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index dfb934b..7bb67ab 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -143,6 +143,11 @@ def __init__(self, pandas_obj: Any): def _validate(obj: Any) -> None: pass + @staticmethod + def process_group(args): + group, user_instruction, all_cols, suffix, progress_bar_desc = args + return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) + def __call__( self, user_instruction: str, @@ -181,19 +186,14 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"column {column} not found in DataFrame. Given usr instruction: {user_instruction}") - - - if group_by: grouped = self._obj.groupby(group_by) - new_df = pd.DataFrame() - for name, group in grouped: - res = group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) - new_df = pd.concat([new_df, res]) - return new_df - - - + group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: + return pd.concat(list(executor.map(SemAggDataframe.process_group, group_args))) + # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: self._obj = self._obj.sort_values(by="_lotus_partition_id") diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 507d844..c92e81d 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -373,6 +373,19 @@ def __init__(self, pandas_obj: Any) -> None: def _validate(obj: Any) -> None: pass + @staticmethod + def process_group(args): + group, user_instruction, K, method, strategy, group_by, cascade_threshold, return_stats = args + return group.sem_topk( + user_instruction, + K, + method=method, + strategy=strategy, + group_by=None, + cascade_threshold=cascade_threshold, + return_stats=return_stats, + ) + def __call__( self, user_instruction: str, @@ -416,30 +429,22 @@ def __call__( # Separate code path for grouping if group_by: grouped = self._obj.groupby(group_by) - new_df = pd.DataFrame() - stats = {} - for name, group in grouped: - res = group.sem_topk( - user_instruction, - K, - method=method, - strategy=strategy, - group_by=None, - cascade_threshold=cascade_threshold, - return_stats=return_stats, - ) - - if return_stats: - sorted_group, group_stats = res - stats[name] = group_stats - else: - sorted_group = res - - new_df = pd.concat([new_df, sorted_group]) + group_args = [ + (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) + for _, group in grouped + ] + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: + results = list(executor.map(SemTopKDataframe.process_group, group_args)) + if return_stats: + new_df = pd.concat([res[0] for res in results]) + stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} return new_df, stats - return new_df + else: + return pd.concat(results) if method == "quick-sem": assert len(col_li) == 1, "Only one column can be used for embedding optimization" diff --git a/lotus/settings.py b/lotus/settings.py index a39be43..ce12363 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,6 +1,8 @@ import lotus.models from lotus.types import SerializationFormat +# NOTE: Settings class is not thread-safe + class Settings: # Models @@ -15,11 +17,17 @@ class Settings: # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT + # Parallel groupby settings + parallel_groupby_max_threads: int = 8 + def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") setattr(self, key, value) + def __str__(self): + return str(vars(self)) + settings = Settings() diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..dc6f871 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,25 @@ +import pytest + +from lotus.settings import SerializationFormat, Settings + + +class TestSettings: + @pytest.fixture + def settings(self): + return Settings() + + def test_initial_values(self, settings): + assert settings.lm is None + assert settings.rm is None + assert settings.helper_lm is None + assert settings.reranker is None + assert settings.enable_cache is False + assert settings.serialization_format == SerializationFormat.DEFAULT + + def test_configure_method(self, settings): + settings.configure(enable_cache=True) + assert settings.enable_cache is True + + def test_invalid_setting(self, settings): + with pytest.raises(ValueError, match="Invalid setting: invalid_setting"): + settings.configure(invalid_setting=True)