
operator level cache #65

Merged: 22 commits, Dec 27, 2024
Changes from 9 commits
42 changes: 42 additions & 0 deletions lotus/cache.py
@@ -1,3 +1,5 @@
import hashlib
import json
import os
import pickle
import sqlite3
@@ -8,6 +10,8 @@
from functools import wraps
from typing import Any, Callable

import pandas as pd

import lotus


@@ -23,6 +27,44 @@ def wrapper(self, *args, **kwargs):
    return wrapper


def operator_cache(func: Callable) -> Callable:
    """Decorator to add operator level caching."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        model = lotus.settings.lm
        use_operator_cache = lotus.settings.enable_operator_cache

        if use_operator_cache and model.cache:

            def serialize(value):
                if isinstance(value, pd.DataFrame):
                    return value.to_json()
                elif hasattr(value, "dict"):
                    return value.dict()
                return value

            serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
            serialized_args = [serialize(arg) for arg in args]
            cache_key = hashlib.sha256(
                json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
            ).hexdigest()

            cached_result = model.cache.get(cache_key)
            if cached_result is not None:
                print(f"Cache hit for {cache_key}")

Review comment (Collaborator): Use lotus.logger rather than prints.

                return cached_result
            print(f"Cache miss for {cache_key}")

            result = func(self, *args, **kwargs)
            model.cache.insert(cache_key, result)
            return result

        return func(self, *args, **kwargs)

    return wrapper


class CacheType(Enum):
    IN_MEMORY = "in_memory"
    SQLITE = "sqlite"
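
For reference, a hedged sketch of the two points raised by the decorator above: the logging change the reviewer asks for, and the key derivation the caching relies on. Both helper names (log_cache_event, make_operator_cache_key) are hypothetical, and the sketch assumes lotus.logger behaves like a standard logging.Logger; the logger is mentioned in the review but its API is not shown in this diff.

import hashlib
import json

import pandas as pd

import lotus


def make_operator_cache_key(*args, **kwargs) -> str:
    """Hypothetical mirror of the decorator's key derivation: serialize, JSON-encode, hash."""

    def serialize(value):
        # Same rules as the decorator: DataFrames become JSON, objects with .dict()
        # are converted, everything else is passed through to json.dumps as-is.
        if isinstance(value, pd.DataFrame):
            return value.to_json()
        elif hasattr(value, "dict"):
            return value.dict()
        return value

    payload = {
        "args": [serialize(arg) for arg in args],
        "kwargs": {key: serialize(value) for key, value in kwargs.items()},
    }
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()


def log_cache_event(hit: bool, cache_key: str) -> None:
    """Logging as the review suggests, instead of print (assumes lotus.logger is a logging.Logger)."""
    if hit:
        lotus.logger.debug(f"Cache hit for {cache_key}")
    else:
        lotus.logger.debug(f"Cache miss for {cache_key}")

With this key scheme, two operator calls whose serialized args and kwargs are equal map to the same key, which is exactly the condition under which the decorator returns the cached result instead of re-running the operator.
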
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_agg.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus.models
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticAggOutput

@@ -148,6 +149,7 @@ def process_group(args):
        group, user_instruction, all_cols, suffix, progress_bar_desc = args
        return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)

    @operator_cache
    def __call__(
        self,
        user_instruction: str,
4 changes: 3 additions & 1 deletion lotus/sem_ops/sem_cluster_by.py
@@ -4,6 +4,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache


@pd.api.extensions.register_dataframe_accessor("sem_cluster_by")
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        col_name: str,
@@ -52,7 +54,7 @@ def __call__(
        self._obj["cluster_id"] = pd.Series(indices, index=self._obj.index)
        # if return_scores:
        #     self._obj["centroid_sim_score"] = pd.Series(scores, index=self._obj.index)

        # if return_centroids:
        #     return self._obj, centroids
        # else:
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_extract.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
@@ -33,7 +34,6 @@ def sem_extract(
    Returns:
        SemanticExtractOutput: The outputs, raw outputs, and quotes.
    """

    # prepare model inputs
    inputs = []
    for doc in docs:
@@ -72,6 +72,7 @@ def _validate(obj: pd.DataFrame) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        input_cols: list[str],
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_filter.py
@@ -5,6 +5,7 @@
from numpy.typing import NDArray

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
from lotus.utils import show_safe_mode
@@ -134,6 +135,7 @@ def _validate(obj: Any) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        user_instruction: str,
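
As context for how the decorated __call__ methods are meant to be used, here is a rough usage sketch. The LM construction, the settings.configure call, and the {Course Name} instruction syntax are drawn from LOTUS's documented usage rather than from this diff, so treat them as assumptions; the operator cache only engages when enable_operator_cache is set and the configured LM carries a cache (the model.cache check in the decorator).

import pandas as pd

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")  # illustrative model; the LM must have model.cache set for the decorator to engage
lotus.settings.configure(lm=lm)
lotus.settings.enable_operator_cache = True

df = pd.DataFrame({"Course Name": ["Linear Algebra", "Art History", "Quantum Mechanics"]})

# The first call runs the operator and stores its result under a key hashed from the
# call's serialized args/kwargs (note: self, and hence the accessor's DataFrame, is not
# part of the key in this revision). The identical second call is served from the cache.
filtered = df.sem_filter("{Course Name} requires a lot of math")
filtered_again = df.sem_filter("{Course Name} requires a lot of math")  # operator-cache hit
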
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_join.py
@@ -4,6 +4,7 @@
from tqdm import tqdm

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, SemanticJoinOutput
from lotus.utils import show_safe_mode
@@ -234,7 +235,6 @@ def sem_join_cascade(
        cot_reasoning=cot_reasoning,
        default=default,
        strategy=strategy,
        show_progress_bar=False,
    )
    pbar.update(num_large)
    pbar.close()
@@ -545,6 +545,7 @@ def _validate(obj: Any) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        other: pd.DataFrame | pd.Series,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_map.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput
from lotus.utils import show_safe_mode
@@ -80,6 +81,7 @@ def _validate(obj: pd.DataFrame) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        user_instruction: str,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_search.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.types import RerankerOutput, RMOutput


@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        col_name: str,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_sim_join.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.models import RM
from lotus.types import RMOutput

@@ -20,6 +21,7 @@ def _validate(obj: Any) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(
        self,
        other: pd.DataFrame,
4 changes: 3 additions & 1 deletion lotus/sem_ops/sem_topk.py
@@ -7,6 +7,7 @@
from tqdm import tqdm

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticTopKOutput
from lotus.utils import show_safe_mode
@@ -386,6 +387,7 @@ def process_group(args):
            return_stats=return_stats,
        )

    @operator_cache
    def __call__(
        self,
        user_instruction: str,
@@ -438,7 +440,7 @@ def __call__(

            with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor:
                results = list(executor.map(SemTopKDataframe.process_group, group_args))

            if return_stats:
                new_df = pd.concat([res[0] for res in results])
                stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)}
1 change: 1 addition & 0 deletions lotus/settings.py
@@ -13,6 +13,7 @@ class Settings:

    # Cache settings
    enable_cache: bool = False

Review comment (Collaborator): Can enable_cache be renamed to enable_message_cache, so it's clear what the distinction between the two flags is?

Reply (Collaborator, author): Yeah, I can do that.

    enable_operator_cache: bool = False

    # Serialization setting
    serialization_format: SerializationFormat = SerializationFormat.DEFAULT
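
To make the distinction discussed above concrete, a small configuration sketch follows. Setting the flags by direct attribute assignment is an assumption based on how the decorator reads lotus.settings.enable_operator_cache, and enable_message_cache is only the proposed rename, not a name present in this diff.

import lotus

# Existing flag: gates the LM-level cache of model calls
# (the review proposes renaming it to enable_message_cache to make this clearer).
lotus.settings.enable_cache = True

# New flag from this PR: caches the final result of a whole semantic operator call
# (sem_filter, sem_map, sem_join, ...), keyed by a hash of the serialized call arguments.
lotus.settings.enable_operator_cache = True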