diff --git a/libertai_agents/libertai_agents/models/base.py b/libertai_agents/libertai_agents/models/base.py index 29c0c94..668467c 100644 --- a/libertai_agents/libertai_agents/models/base.py +++ b/libertai_agents/libertai_agents/models/base.py @@ -76,9 +76,25 @@ def generate_prompt( ) raw_messages = [x.model_dump() for x in messages] - for i in range(len(raw_messages)): + prompt = ( + self.tokenizer.apply_chat_template( + conversation=system_messages, + tools=[x.args_schema for x in tools], + tokenize=False, + add_generation_prompt=True, + ) + if len(system_messages) != 0 + else "" + ) + if self.__count_tokens(prompt) > self.context_length: + raise ValueError( + f"Can't fit system messages into the available context length ({self.context_length} tokens)" + ) + + # Adding as many messages as we can fit into the context, starting from the last ones + for i in reversed(range(len(raw_messages))): included_messages: list = system_messages + raw_messages[i:] - prompt = self.tokenizer.apply_chat_template( + new_prompt = self.tokenizer.apply_chat_template( conversation=included_messages, tools=[x.args_schema for x in tools], tokenize=False, @@ -86,11 +102,11 @@ def generate_prompt( ) if not isinstance(prompt, str): raise TypeError("Generated prompt isn't a string") - if self.__count_tokens(prompt) <= self.context_length: + if self.__count_tokens(new_prompt) >= self.context_length: return prompt - raise ValueError( - f"Can't fit messages into the available context length ({self.context_length} tokens)" - ) + prompt = new_prompt + + return prompt def generate_tool_call_id(self) -> str | None: """