diff --git a/README.md b/README.md
index 8eb9dd0..7765a61 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,16 @@
 # LibertAI Agents
+
+## Supported models
+
+We support multiple open-source models with agentic capabilities.
+
+- [Hermes 2 Pro - Llama 3 8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B)
+- ⏳ [Hermes 3 - Llama-3.1 8B](https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B)
+- ⏳ [Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)
+
+## Using a gated model
+
+Some models, like [Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407), are gated
+(generally requiring you to accept some usage conditions).\
+To use these models, you need to create an [access token](https://huggingface.co/settings/tokens) in your Hugging Face
+account and pass it to the `get_model` function.
\ No newline at end of file
diff --git a/libertai_agents/models/base.py b/libertai_agents/models/base.py
index d96627e..e26efea 100644
--- a/libertai_agents/models/base.py
+++ b/libertai_agents/models/base.py
@@ -12,22 +12,35 @@ class Model(ABC):
     tokenizer: PreTrainedTokenizerFast
     vm_url: str
+    context_length: int
     system_message: bool
 
-    def __init__(self, model_id: str, vm_url: str, system_message: bool = True):
+    def __init__(self, model_id: str, vm_url: str, context_length: int, system_message: bool = True):
         from transformers import AutoTokenizer
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.vm_url = vm_url
+        self.context_length = context_length
         self.system_message = system_message
 
+    def __count_tokens(self, content: str) -> int:
+        tokens = self.tokenizer.tokenize(content)
+        return len(tokens)
+
     def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
-        if self.system_message:
-            messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+        system_message = Message(role=MessageRoleEnum.system, content=system_prompt)
 
         raw_messages = list(map(lambda x: x.model_dump(), messages))
 
-        return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
-                                                  add_generation_prompt=True)
+        for i in range(len(raw_messages)):
+            included_messages: list = [system_message] + raw_messages[i:]
+            prompt = self.tokenizer.apply_chat_template(conversation=included_messages, tools=tools,
+                                                        tokenize=False,
+                                                        add_generation_prompt=True)
+            if not isinstance(prompt, str):
+                raise TypeError("Generated prompt isn't a string")
+            if self.__count_tokens(prompt) <= self.context_length:
+                return prompt
+        raise ValueError(f"Can't fit messages into the available context length ({self.context_length} tokens)")
 
     def generate_tool_call_id(self) -> str | None:
         return None
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
index 1500321..d14be38 100644
--- a/libertai_agents/models/hermes.py
+++ b/libertai_agents/models/hermes.py
@@ -6,10 +6,13 @@
 
 
 class HermesModel(Model):
-    def __init__(self, model_id: str, vm_url: str):
-        super().__init__(model_id, vm_url)
+    def __init__(self, model_id: str, vm_url: str, context_length: int):
+        super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length)
 
     @staticmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
-        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+        try:
+            tool_calls = re.findall(r'<tool_call>\s*(.*)\s*</tool_call>', response)
+            return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+        except Exception:
+            return []
diff --git a/libertai_agents/models/mistral.py b/libertai_agents/models/mistral.py
index cddc814..e3e737e 100644
--- a/libertai_agents/models/mistral.py
+++ b/libertai_agents/models/mistral.py
@@ -7,8 +7,8 @@
 
 
 class MistralModel(Model):
-    def __init__(self, model_id: str, vm_url: str):
-        super().__init__(model_id, vm_url, system_message=False)
+    def __init__(self, model_id: str, vm_url: str, context_length: int):
+        super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length, system_message=False)
 
     @staticmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
diff --git a/libertai_agents/models/models.py b/libertai_agents/models/models.py
index 66763a8..c92f6e2 100644
--- a/libertai_agents/models/models.py
+++ b/libertai_agents/models/models.py
@@ -10,6 +10,7 @@
 
 class ModelConfiguration(BaseModel):
     vm_url: str
+    context_length: int
     constructor: typing.Type[Model]
 
 
@@ -20,13 +21,17 @@
 ]
 MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))
 
+# TODO: update URLs with prod, and check context size (if we deploy it with a lower one)
 MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
     "NousResearch/Hermes-2-Pro-Llama-3-8B": ModelConfiguration(
         vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
+        context_length=8192,
         constructor=HermesModel),
     "NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="http://localhost:8080/completion",
+                                                             context_length=131_072,
                                                              constructor=HermesModel),
     "mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="http://localhost:8080/completion",
+                                                               context_length=131_072,
                                                                constructor=MistralModel)
 }
 
diff --git a/main.py b/main.py
index 262df25..89d477b 100644
--- a/main.py
+++ b/main.py
@@ -7,7 +7,7 @@
 
 
 async def start():
-    agent = ChatAgent(model=get_model("mistralai/Mistral-Nemo-Instruct-2407"),
+    agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                       system_prompt="You are a helpful assistant",
                       tools=[get_current_temperature])
     response = await agent.generate_answer(
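
As a companion to the "Using a gated model" README section added above, here is a minimal usage sketch. It follows the `ChatAgent`/`get_model` pattern visible in `main.py`, but the import paths and the `hf_token` keyword argument name are assumptions for illustration only; this diff does not show `get_model`'s actual signature or the package's public import paths.

```python
# Minimal sketch of the gated-model flow described in the README above.
# Assumptions (not confirmed by this diff): the import paths below and the
# `hf_token` keyword argument name on get_model are hypothetical.
import asyncio
import os

from libertai_agents.agents import ChatAgent  # assumed import path
from libertai_agents.models import get_model  # assumed import path


async def start():
    # Access token created at https://huggingface.co/settings/tokens,
    # exported here as the HF_TOKEN environment variable.
    hf_token = os.environ["HF_TOKEN"]

    # Mistral-Nemo-Instruct-2407 is gated, so fetching its tokenizer
    # requires passing the token to get_model.
    model = get_model("mistralai/Mistral-Nemo-Instruct-2407", hf_token=hf_token)

    agent = ChatAgent(model=model,
                      system_prompt="You are a helpful assistant",
                      tools=[])
    return agent


if __name__ == "__main__":
    asyncio.run(start())
```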