single cache flag
StanChan03 committed Jan 11, 2025
1 parent 445d3b6 commit a21e057
Showing 4 changed files with 12 additions and 13 deletions.
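
The change collapses the two previous settings, `enable_message_cache` (the LM-level message cache) and `enable_operator_cache` (the operator-level cache), into a single `enable_cache` flag that governs both. Below is a minimal sketch of configuring caching after this commit; the `LM` constructor and `lotus.settings.configure` call appear in the diff, while importing `LM` from `lotus.models` is an assumption based on the test file:

```python
import lotus
from lotus.models import LM  # import path assumed from the test file

lm = LM(model="gpt-4o-mini")

# Before this commit, the two caches were toggled separately:
#   lotus.settings.configure(lm=lm, enable_message_cache=True)
#   lotus.settings.configure(lm=lm, enable_operator_cache=True)
# After it, one flag turns both on:
lotus.settings.configure(lm=lm, enable_cache=True)
```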
10 changes: 5 additions & 5 deletions .github/tests/lm_tests.py
@@ -399,7 +399,7 @@ def test_custom_tokenizer():
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_message_cache=True)
+    lotus.settings.configure(lm=lm, enable_cache=True)
 
     # Check that "What is the capital of France?" becomes cached
     first_batch = [
@@ -428,7 +428,7 @@ def test_cache(setup_models, model):
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_disable_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_message_cache=False)
+    lotus.settings.configure(lm=lm, enable_cache=False)
 
     batch = [
         [{"role": "user", "content": "Hello, world!"}],
@@ -452,7 +452,7 @@ def test_disable_cache(setup_models, model):
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_reset_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_message_cache=True)
+    lotus.settings.configure(lm=lm, enable_cache=True)
 
     batch = [
         [{"role": "user", "content": "Hello, world!"}],
@@ -482,7 +482,7 @@ def test_operator_cache(setup_models, model):
     cache = CacheFactory.create_cache(cache_config)
 
     lm = LM(model="gpt-4o-mini", cache=cache)
-    lotus.settings.configure(lm=lm, enable_operator_cache=True)
+    lotus.settings.configure(lm=lm, enable_cache=True)
 
     data = {
         "Course Name": [
@@ -538,7 +538,7 @@ def test_disable_operator_cache(setup_models, model):
     cache = CacheFactory.create_cache(cache_config)
 
     lm = LM(model="gpt-4o-mini", cache=cache)
-    lotus.settings.configure(lm=lm, enable_operator_cache=False)
+    lotus.settings.configure(lm=lm, enable_cache=False)
 
     data = {
         "Course Name": [
4 changes: 2 additions & 2 deletions lotus/cache.py
@@ -20,7 +20,7 @@ def require_cache_enabled(func: Callable) -> Callable:
 
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if not lotus.settings.enable_message_cache:
+        if not lotus.settings.enable_cache:
             return None
         return func(self, *args, **kwargs)
 
@@ -33,7 +33,7 @@ def operator_cache(func: Callable) -> Callable:
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         model = lotus.settings.lm
-        use_operator_cache = lotus.settings.enable_operator_cache
+        use_operator_cache = lotus.settings.enable_cache
 
         if use_operator_cache and model.cache:
 
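Both decorators in `lotus/cache.py` now consult the same flag. Here is a sketch of the `require_cache_enabled` gate as it stands after this change, completed with the imports the hunk elides; the trailing `return wrapper` is assumed, since the hunk cuts off before it:

```python
from functools import wraps
from typing import Callable

import lotus


def require_cache_enabled(func: Callable) -> Callable:
    """Skip the wrapped cache operation entirely when caching is disabled."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # A single settings flag now gates both this decorator and operator_cache.
        if not lotus.settings.enable_cache:
            return None
        return func(self, *args, **kwargs)

    return wrapper  # assumed; the hunk ends before the decorator returns
```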
8 changes: 4 additions & 4 deletions lotus/models/lm.py
@@ -55,14 +55,14 @@ def __call__(
         if all_kwargs.get("logprobs", False):
             all_kwargs.setdefault("top_logprobs", 10)
 
-        if lotus.settings.enable_message_cache:
+        if lotus.settings.enable_cache:
             # Check cache and separate cached and uncached messages
             hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
             cached_responses = [self.cache.get(hash) for hash in hashed_messages]
 
         uncached_data = (
             [(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None]
-            if lotus.settings.enable_message_cache
+            if lotus.settings.enable_cache
             else [(msg, "no-cache") for msg in messages]
         )
 
@@ -72,7 +72,7 @@ def __call__(
         uncached_responses = self._process_uncached_messages(
             uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
         )
-        if lotus.settings.enable_message_cache:
+        if lotus.settings.enable_cache:
             # Add new responses to cache
             for resp, (_, hash) in zip(uncached_responses, uncached_data):
                 if hash:
@@ -81,7 +81,7 @@ def __call__(
         # Merge all responses in original order and extract outputs
         all_responses = (
             self._merge_responses(cached_responses, uncached_responses)
-            if lotus.settings.enable_message_cache
+            if lotus.settings.enable_cache
            else uncached_responses
         )
         outputs = [self._get_top_choice(resp) for resp in all_responses]
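The `__call__` hunks show the flow that `enable_cache` now gates: hash each message batch, look up cached responses, send only the misses to the model, write the fresh responses back, and merge everything in the original order. A standalone sketch of that split-and-merge pattern follows; the helper callables are illustrative stand-ins, not the library's actual signatures:

```python
def call_with_cache(messages, cache, complete_batch, hash_messages):
    # Hash each message batch and probe the cache (get returns None on a miss).
    hashes = [hash_messages(msg) for msg in messages]
    cached = [cache.get(h) for h in hashes]

    # Only the misses are sent to the model.
    misses = [(msg, h) for msg, h, resp in zip(messages, hashes, cached) if resp is None]
    fresh = complete_batch([msg for msg, _ in misses])

    # Write fresh responses back so identical batches hit the cache next time.
    for resp, (_, h) in zip(fresh, misses):
        cache.insert(h, resp)

    # Merge cached and fresh responses back into the original request order.
    fresh_iter = iter(fresh)
    return [resp if resp is not None else next(fresh_iter) for resp in cached]
```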
3 changes: 1 addition & 2 deletions lotus/settings.py
@@ -12,8 +12,7 @@ class Settings:
     reranker: lotus.models.Reranker | None = None
 
     # Cache settings
-    enable_message_cache: bool = False
-    enable_operator_cache: bool = False
+    enable_cache: bool = False
 
     # Serialization setting
     serialization_format: SerializationFormat = SerializationFormat.DEFAULT
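One behavioral consequence of the consolidation: message caching and operator caching can no longer be toggled independently, so any caller that set either old field must now set `enable_cache`. A hypothetical migration helper (not part of the library) that maps old settings dicts onto the new flag:

```python
def migrate_cache_settings(old: dict) -> dict:
    """Map the two retired flags onto the single enable_cache flag."""
    new = dict(old)
    message = new.pop("enable_message_cache", False)
    operator = new.pop("enable_operator_cache", False)
    # If either cache was on before, caching stays on under the unified flag.
    new["enable_cache"] = message or operator
    return new
```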
