feat: Support context length and generate prompt based on it
RezaRahemtola committed Aug 25, 2024
1 parent 996cdbe · commit c4a5b1e
Showing 6 changed files with 48 additions and 12 deletions.
README.md — 15 additions, 0 deletions
@@ -1 +1,16 @@
# LibertAI Agents

## Supported models

We support multiple open-source models that have agentic capabilities.

- [Hermes 2 Pro - Llama 3 8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B)
- [Hermes 3 - Llama-3.1 8B](https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B)
- [Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)

## Using a gated model

Some models, like [Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407), are gated
(generally to require you to accept some usage conditions).\
To use those models, you need to create an [access token](https://huggingface.co/settings/tokens) from your Hugging Face
account and give it to the `get_model` function.
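
For example, a minimal sketch of how the token might be passed (the import path and the `hf_token` keyword below are assumptions, not confirmed by this diff — check `get_model`'s actual signature):

```python
from libertai_agents.models import get_model  # assumed import path

# Hypothetical call: the keyword used to pass the token is an assumption.
model = get_model(
    "mistralai/Mistral-Nemo-Instruct-2407",
    hf_token="hf_xxx",  # token created at https://huggingface.co/settings/tokens
)
```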
libertai_agents/models/base.py — 18 additions, 5 deletions
@@ -12,22 +12,35 @@ class Model(ABC):
     tokenizer: PreTrainedTokenizerFast
     vm_url: str
+    context_length: int
     system_message: bool

-    def __init__(self, model_id: str, vm_url: str, system_message: bool = True):
+    def __init__(self, model_id: str, vm_url: str, context_length: int, system_message: bool = True):
         from transformers import AutoTokenizer

         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.vm_url = vm_url
+        self.context_length = context_length
         self.system_message = system_message

+    def __count_tokens(self, content: str) -> int:
+        tokens = self.tokenizer.tokenize(content)
+        return len(tokens)
+
     def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
-        if self.system_message:
-            messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+        system_message = Message(role=MessageRoleEnum.system, content=system_prompt)
         raw_messages = list(map(lambda x: x.model_dump(), messages))

-        return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
-                                                  add_generation_prompt=True)
+        for i in range(len(raw_messages)):
+            included_messages: list = [system_message] + raw_messages[i:]
+            prompt = self.tokenizer.apply_chat_template(conversation=included_messages, tools=tools,
+                                                        tokenize=False,
+                                                        add_generation_prompt=True)
+            if not isinstance(prompt, str):
+                raise TypeError("Generated prompt isn't a string")
+            if self.__count_tokens(prompt) <= self.context_length:
+                return prompt
+        raise ValueError(f"Can't fit messages into the available context length ({self.context_length} tokens)")

     def generate_tool_call_id(self) -> str | None:
         return None
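
The new `generate_prompt` drops the oldest messages one at a time (always re-inserting the system prompt first) until the rendered chat template fits within `context_length`, and raises if even the most recent message alone does not fit. A standalone sketch of the same idea using only a Hugging Face tokenizer — the model ID and token budget below are illustrative, not part of the commit:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B")
context_length = 8192  # same budget configured for this model in MODELS_CONFIG

system = {"role": "system", "content": "You are a helpful assistant"}
history = [{"role": "user", "content": f"message {i}"} for i in range(1000)]

# Try progressively shorter suffixes of the history until the rendered prompt fits.
prompt = None
for i in range(len(history)):
    candidate = [system] + history[i:]
    rendered = tokenizer.apply_chat_template(candidate, tokenize=False, add_generation_prompt=True)
    if len(tokenizer.tokenize(rendered)) <= context_length:
        prompt = rendered
        break
if prompt is None:
    raise ValueError(f"Can't fit messages into the available context length ({context_length} tokens)")
```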
libertai_agents/models/hermes.py — 7 additions, 4 deletions
@@ -6,10 +6,13 @@

 class HermesModel(Model):
-    def __init__(self, model_id: str, vm_url: str):
-        super().__init__(model_id, vm_url)
+    def __init__(self, model_id: str, vm_url: str, context_length: int):
+        super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length)

     @staticmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
-        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+        try:
+            tool_calls = re.findall(r'<tool_call>\s*(.*)\s*</tool_call>', response)
+            return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+        except Exception:
+            return []
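
For context, Hermes models emit tool calls as JSON wrapped in `<tool_call>` tags. A quick usage sketch — the sample response string and the JSON field names are assumptions about `ToolCallFunction`'s schema, not taken from this diff:

```python
from libertai_agents.models.hermes import HermesModel  # module path as shown in this diff

# Hypothetical model output containing one tool call.
response = (
    "<tool_call> "
    '{"name": "get_current_temperature", "arguments": {"location": "Paris"}} '
    "</tool_call>"
)

calls = HermesModel.extract_tool_calls_from_response(response)
# With the try/except added in this commit, a malformed payload (e.g. truncated JSON)
# now yields [] instead of raising.
```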
libertai_agents/models/mistral.py — 2 additions, 2 deletions
@@ -7,8 +7,8 @@

 class MistralModel(Model):
-    def __init__(self, model_id: str, vm_url: str):
-        super().__init__(model_id, vm_url, system_message=False)
+    def __init__(self, model_id: str, vm_url: str, context_length: int):
+        super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length, system_message=False)

     @staticmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
libertai_agents/models/models.py — 5 additions, 0 deletions
@@ -10,6 +10,7 @@

 class ModelConfiguration(BaseModel):
     vm_url: str
+    context_length: int
     constructor: typing.Type[Model]

@@ -20,13 +21,17 @@ class ModelConfiguration(BaseModel):
 ]
 MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))

+# TODO: update URLs with prod, and check context size (if we deploy it with a lower one)
 MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
     "NousResearch/Hermes-2-Pro-Llama-3-8B": ModelConfiguration(
         vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
+        context_length=8192,
         constructor=HermesModel),
     "NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="http://localhost:8080/completion",
+                                                              context_length=131_072,
                                                               constructor=HermesModel),
     "mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="http://localhost:8080/completion",
+                                                                context_length=131_072,
                                                                 constructor=MistralModel)
 }
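
`get_model` itself is not part of this diff, so the following is only a plausible sketch of how this configuration could be consumed, living alongside `MODELS_CONFIG` in the same module; the function body is an assumption:

```python
def get_model(model_id: ModelId) -> Model:
    # Assumed behaviour: resolve the per-model configuration and delegate to its
    # constructor, which now forwards context_length into Model.__init__.
    config = MODELS_CONFIG[model_id]
    return config.constructor(model_id=model_id, vm_url=config.vm_url,
                              context_length=config.context_length)
```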

main.py — 1 addition, 1 deletion
@@ -7,7 +7,7 @@

 async def start():
-    agent = ChatAgent(model=get_model("mistralai/Mistral-Nemo-Instruct-2407"),
+    agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                       system_prompt="You are a helpful assistant",
                       tools=[get_current_temperature])
     response = await agent.generate_answer(
