diff --git a/README.md b/README.md
index 8eb9dd0..7765a61 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,16 @@
# LibertAI Agents
+
+## Supported models
+
+We support multiple open-source models that have agentic capabilities (e.g. function/tool calling).
+
+- [Hermes 2 Pro - Llama 3 8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B)
+- ⏳ [Hermes 3 - Llama-3.1 8B](https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B)
+- ⏳ [Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)
+
+## Using a gated model
+
+Some models, like [Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407), are gated
+(generally meaning that you have to accept some usage conditions).\
+To use those models, create an [access token](https://huggingface.co/settings/tokens) from your Hugging Face
+account and pass it to the `get_model` function.
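+
+For example, a minimal sketch (the import path and the exact keyword argument, e.g. `hf_token`, are assumptions — check the `get_model` signature in your version):
+
+```python
+from libertai_agents.models import get_model  # import path assumed
+
+# Hypothetical: pass your Hugging Face access token so the gated tokenizer can be downloaded
+model = get_model("mistralai/Mistral-Nemo-Instruct-2407", hf_token="hf_xxx")
+```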
\ No newline at end of file
diff --git a/libertai_agents/models/base.py b/libertai_agents/models/base.py
index d96627e..e26efea 100644
--- a/libertai_agents/models/base.py
+++ b/libertai_agents/models/base.py
@@ -12,22 +12,35 @@ class Model(ABC):
tokenizer: PreTrainedTokenizerFast
vm_url: str
+ context_length: int
system_message: bool
- def __init__(self, model_id: str, vm_url: str, system_message: bool = True):
+ def __init__(self, model_id: str, vm_url: str, context_length: int, system_message: bool = True):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.vm_url = vm_url
+ self.context_length = context_length
self.system_message = system_message
+ def __count_tokens(self, content: str) -> int:
+ tokens = self.tokenizer.tokenize(content)
+ return len(tokens)
+
def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
- if self.system_message:
- messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+ system_message = Message(role=MessageRoleEnum.system, content=system_prompt)
raw_messages = list(map(lambda x: x.model_dump(), messages))
- return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
- add_generation_prompt=True)
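+        # Drop the oldest messages one at a time until the rendered prompt fits the model's context window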
+ for i in range(len(raw_messages)):
+            included_messages: list = [system_message.model_dump()] + raw_messages[i:]
+ prompt = self.tokenizer.apply_chat_template(conversation=included_messages, tools=tools,
+ tokenize=False,
+ add_generation_prompt=True)
+ if not isinstance(prompt, str):
+ raise TypeError("Generated prompt isn't a string")
+ if self.__count_tokens(prompt) <= self.context_length:
+ return prompt
+ raise ValueError(f"Can't fit messages into the available context length ({self.context_length} tokens)")
def generate_tool_call_id(self) -> str | None:
return None
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
index 1500321..d14be38 100644
--- a/libertai_agents/models/hermes.py
+++ b/libertai_agents/models/hermes.py
@@ -6,10 +6,13 @@
class HermesModel(Model):
- def __init__(self, model_id: str, vm_url: str):
- super().__init__(model_id, vm_url)
+ def __init__(self, model_id: str, vm_url: str, context_length: int):
+ super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length)
@staticmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
- return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+ try:
+            tool_calls = re.findall(r'<tool_call>\s*(.*)\s*</tool_call>', response)
+ return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+ except Exception:
+ return []
diff --git a/libertai_agents/models/mistral.py b/libertai_agents/models/mistral.py
index cddc814..e3e737e 100644
--- a/libertai_agents/models/mistral.py
+++ b/libertai_agents/models/mistral.py
@@ -7,8 +7,8 @@
class MistralModel(Model):
- def __init__(self, model_id: str, vm_url: str):
- super().__init__(model_id, vm_url, system_message=False)
+ def __init__(self, model_id: str, vm_url: str, context_length: int):
+ super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length, system_message=False)
@staticmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
diff --git a/libertai_agents/models/models.py b/libertai_agents/models/models.py
index 66763a8..c92f6e2 100644
--- a/libertai_agents/models/models.py
+++ b/libertai_agents/models/models.py
@@ -10,6 +10,7 @@
class ModelConfiguration(BaseModel):
vm_url: str
+ context_length: int
constructor: typing.Type[Model]
@@ -20,13 +21,17 @@ class ModelConfiguration(BaseModel):
]
MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))
+# TODO: update URLs to the production ones, and double-check the context size (in case we deploy with a smaller one)
MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
"NousResearch/Hermes-2-Pro-Llama-3-8B": ModelConfiguration(
vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
+ context_length=8192,
constructor=HermesModel),
"NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="http://localhost:8080/completion",
+ context_length=131_072,
constructor=HermesModel),
"mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="http://localhost:8080/completion",
+ context_length=131_072,
constructor=MistralModel)
}
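+
+# Illustrative sketch (comment only): a configuration entry provides everything needed to
+# instantiate the matching Model subclass, e.g.
+#   config = MODELS_CONFIG["NousResearch/Hermes-2-Pro-Llama-3-8B"]
+#   model = config.constructor(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
+#                              vm_url=config.vm_url, context_length=config.context_length)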
diff --git a/main.py b/main.py
index 262df25..89d477b 100644
--- a/main.py
+++ b/main.py
@@ -7,7 +7,7 @@
async def start():
- agent = ChatAgent(model=get_model("mistralai/Mistral-Nemo-Instruct-2407"),
+ agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
system_prompt="You are a helpful assistant",
tools=[get_current_temperature])
response = await agent.generate_answer(