From c63401075789f4e3d2d2aefa9b752165d6b56308 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Mon, 11 Nov 2024 23:15:32 -0800 Subject: [PATCH] Minor change --- .github/tests/lm_tests.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 4b660bd7..333fa201 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -47,6 +47,7 @@ def print_usage_after_each_test(setup_models): print(f"\nUsage stats for {model_name} after test:") model.print_total_usage() model.reset_stats() + model.reset_cache() ################################################################################ @@ -315,20 +316,20 @@ def test_disable_cache(setup_models, model): lm.disable_cache() lotus.settings.configure(lm=lm) - first_batch = [ + batch = [ [{"role": "user", "content": "Hello, world!"}], [{"role": "user", "content": "What is the capital of France?"}], ] - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 0 - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 0 # Now enable cache. Note that the first batch is not cached. lm.enable_cache() - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 0 - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 2 @@ -338,23 +339,23 @@ def test_reset_cache(setup_models, model): lm.reset_cache() lotus.settings.configure(lm=lm) - first_batch = [ + batch = [ [{"role": "user", "content": "Hello, world!"}], [{"role": "user", "content": "What is the capital of France?"}], ] - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 0 - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 2 lm.reset_cache(max_size=1) - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 2 - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 3 lm.reset_cache(max_size=0) - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 3 - lm(first_batch) + lm(batch) assert lm.stats.total_usage.cache_hits == 3