From cb22635bffe294e9a2a7c5ffef14949c49f31586 Mon Sep 17 00:00:00 2001
From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com>
Date: Sat, 30 Nov 2024 10:38:47 -0500
Subject: [PATCH] Fix memory issue with llama.cpp LLM pipeline, closes #824

---
 src/python/txtai/pipeline/llm/llama.py        | 40 +++++++++++++++---
 .../testpipeline/testaudio/testmicrophone.py  |  2 +-
 test/python/testpipeline/testllm/testllama.py | 42 +++++++++++++++++++
 3 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/src/python/txtai/pipeline/llm/llama.py b/src/python/txtai/pipeline/llm/llama.py
index 841958a8e..2328fd0a7 100644
--- a/src/python/txtai/pipeline/llm/llama.py
+++ b/src/python/txtai/pipeline/llm/llama.py
@@ -8,7 +8,7 @@
 
 # Conditional import
 try:
-    from llama_cpp import Llama
+    import llama_cpp as llama
 
     LLAMA_CPP = True
 except ImportError:
@@ -45,11 +45,8 @@ def __init__(self, path, template=None, **kwargs):
         # Check if this is a local path, otherwise download from the HF Hub
         path = path if os.path.exists(path) else self.download(path)
 
-        # Default GPU layers if not already set
-        kwargs["n_gpu_layers"] = kwargs.get("n_gpu_layers", -1 if kwargs.get("gpu", os.environ.get("LLAMA_NO_METAL") != "1") else 0)
-
         # Create llama.cpp instance
-        self.llm = Llama(path, n_ctx=0, verbose=kwargs.pop("verbose", False), **kwargs)
+        self.llm = self.create(path, **kwargs)
 
     def stream(self, texts, maxlength, stream, stop, **kwargs):
         for text in texts:
@@ -79,6 +76,39 @@ def download(self, path):
         # Download and cache file
         return hf_hub_download(repo_id="/".join(parts[:repo]), filename="/".join(parts[repo:]))
 
+    def create(self, path, **kwargs):
+        """
+        Creates a new llama.cpp model instance.
+
+        Args:
+            path: path to model
+            kwargs: additional keyword args
+
+        Returns:
+            llama.cpp instance
+        """
+
+        # Default n_ctx=0 if not already set. This sets n_ctx = n_ctx_train.
+        kwargs["n_ctx"] = kwargs.get("n_ctx", 0)
+
+        # Default GPU layers if not already set
+        kwargs["n_gpu_layers"] = kwargs.get("n_gpu_layers", -1 if kwargs.get("gpu", os.environ.get("LLAMA_NO_METAL") != "1") else 0)
+
+        # Default verbose flag
+        kwargs["verbose"] = kwargs.get("verbose", False)
+
+        # Create llama.cpp instance
+        try:
+            return llama.Llama(model_path=path, **kwargs)
+        except ValueError as e:
+            # Fallback to default n_ctx when not enough memory for n_ctx = n_ctx_train
+            if not kwargs["n_ctx"]:
+                kwargs.pop("n_ctx")
+                return llama.Llama(model_path=path, **kwargs)
+
+            # Raise exception if n_ctx manually specified
+            raise e
+
     def messages(self, messages, maxlength, stream, stop, **kwargs):
         """
         Processes a list of messages.
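For context, a minimal usage sketch of the fallback behavior introduced above. It assumes llama-cpp-python is installed; the model path is the GGUF file used in the tests and the prompt string is illustrative, not from the patch.

from txtai.pipeline import LLM

# GGUF model referenced in the tests below
path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"

# No n_ctx: create() requests n_ctx=0 (i.e. n_ctx = n_ctx_train). If llama.cpp
# raises ValueError because that context size does not fit in memory, create()
# retries without n_ctx so llama.cpp falls back to its own default context.
llm = LLM(path)

# Explicit n_ctx is passed through unchanged; if it is too large, the
# ValueError is re-raised instead of being silently downgraded.
llm = LLM(path, n_ctx=1024)

print(llm("Where is one place to visit in Washington, DC?"))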
diff --git a/test/python/testpipeline/testaudio/testmicrophone.py b/test/python/testpipeline/testaudio/testmicrophone.py
index 2b7a9bcfc..baa94baeb 100644
--- a/test/python/testpipeline/testaudio/testmicrophone.py
+++ b/test/python/testpipeline/testaudio/testmicrophone.py
@@ -70,7 +70,7 @@ def int16(self, data):
             return (data * absmax + offset).clip(i.min, i.max).astype(np.int16)
 
         # Mock input stream
-        inputstream.return_value = RawInputStream()
+        inputstream.side_effect = RawInputStream
 
         # Create microphone pipeline and read data
         pipeline = Microphone()
diff --git a/test/python/testpipeline/testllm/testllama.py b/test/python/testpipeline/testllm/testllama.py
index f7c0e10d2..2dc4d007a 100644
--- a/test/python/testpipeline/testllm/testllama.py
+++ b/test/python/testpipeline/testllm/testllama.py
@@ -4,6 +4,8 @@
 
 import unittest
 
+from unittest.mock import patch
+
 from txtai.pipeline import LLM
 
 
@@ -12,6 +14,46 @@ class TestLlama(unittest.TestCase):
     llama.cpp tests.
     """
 
+    @patch("llama_cpp.Llama")
+    def testContext(self, llama):
+        """
+        Test n_ctx with llama.cpp
+        """
+
+        class Llama:
+            """
+            Mock llama.cpp instance to test invalid context
+            """
+
+            def __init__(self, **kwargs):
+                if kwargs.get("n_ctx") == 0 or kwargs.get("n_ctx", 0) >= 10000:
+                    raise ValueError("Failed to create context")
+
+                # Save parameters
+                self.params = kwargs
+
+        # Mock llama.cpp instance
+        llama.side_effect = Llama
+
+        # Model to test
+        path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
+
+        # Test omitting n_ctx falls back to default settings
+        llm = LLM(path)
+        self.assertNotIn("n_ctx", llm.generator.llm.params)
+
+        # Test n_ctx=0 falls back to default settings
+        llm = LLM(path, n_ctx=0)
+        self.assertNotIn("n_ctx", llm.generator.llm.params)
+
+        # Test n_ctx manually set
+        llm = LLM(path, n_ctx=1024)
+        self.assertEqual(llm.generator.llm.params["n_ctx"], 1024)
+
+        # Mock a value for n_ctx that's too big
+        with self.assertRaises(ValueError):
+            llm = LLM(path, n_ctx=10000)
+
     def testGeneration(self):
         """
         Test generation with llama.cpp
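For context on the side_effect changes above, a small standalone sketch of the mock pattern both tests rely on: assigning a class to side_effect makes each call to the patched target construct a fresh instance with the caller's keyword arguments, which is what lets the mocks capture parameters such as n_ctx. The FakeStream name and arguments are illustrative, not from the patch.

from unittest.mock import MagicMock

class FakeStream:
    """Illustrative stand-in constructed on every call."""

    def __init__(self, **kwargs):
        # Capture the constructor arguments the caller passed in
        self.params = kwargs

# side_effect with a class: each call runs FakeStream(**kwargs) and returns
# that new instance, so per-call arguments are visible to assertions.
# return_value = FakeStream() would hand back one pre-built object and the
# call arguments would never reach FakeStream.__init__.
mock = MagicMock(side_effect=FakeStream)

stream = mock(samplerate=16000, channels=1)
assert isinstance(stream, FakeStream)
assert stream.params["samplerate"] == 16000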