Allow passing system instructions to LLM
DonggeLiu committed Nov 13, 2024
1 parent d5c81ad commit ca34bff
Showing 2 changed files with 12 additions and 7 deletions.
agent/prototyper.py (9 additions, 6 deletions)

@@ -10,7 +10,6 @@
 from agent.base_agent import BaseAgent
 from data_prep.project_context.context_introspector import ContextRetriever
 from experiment.benchmark import Benchmark
-from llm_toolkit.prompt_builder import EXAMPLES as EXAMPLE_FUZZ_TARGETS
 from llm_toolkit.prompt_builder import (DefaultTemplateBuilder,
                                         PrototyperTemplateBuilder)
 from llm_toolkit.prompts import Prompt
@@ -34,14 +33,12 @@ def _initial_prompt(self, results: list[Result]) -> Prompt:
     )
     prompt = prompt_builder.build(example_pair=[],
                                   project_context_content=context_info)
-    self.llm._system_instruction = prompt_builder.system_instructions(
+    self.llm.system_instruction = prompt_builder.system_instructions(
         benchmark, [
            'prototyper-system-instruction-objective.txt',
            'prototyper-system-instruction-protocols.txt'
        ])
-    # prompt = prompt_builder.build(example_pair=EXAMPLE_FUZZ_TARGETS.get(
-    #     benchmark.language, []),
-    #                               tool_guides=self.inspect_tool.tutorial())
+    self.protocol = self.llm.system_instruction[-1]
     return prompt

   def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
@@ -222,6 +219,12 @@ def _container_handle_conclusion(
     prompt = DefaultTemplateBuilder(self.llm, initial=prompt_text).build([])
     return prompt

+  def _container_handle_invalid_tool_usage(self) -> Prompt:
+    """Formats a prompt to re-teach LLM how to use the |tool|."""
+    prompt_text = (f'No valid instruction received, Please follow the system '
+                   f'instructions:\n{self.protocol}')
+    return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
+
   def _container_tool_reaction(self, cur_round: int, response: str,
                                build_result: BuildResult) -> Optional[Prompt]:
     """Validates LLM conclusion or executes its command."""
@@ -235,7 +238,7 @@ def _container_tool_reaction(self, cur_round: int, response: str,
     # Other responses are invalid.
     logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round,
                    response)
-    return self._container_handle_invalid_tool_usage(self.inspect_tool)
+    return self._container_handle_invalid_tool_usage()

   def execute(self, result_history: list[Result]) -> BuildResult:
     """Executes the agent based on previous result."""
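Taken together, the prototyper now writes to the LLM's public system_instruction attribute instead of the private _system_instruction field, and caches the protocol text (the last instruction file) so _container_handle_invalid_tool_usage can replay it without needing the inspect tool. A minimal sketch of that flow, using hypothetical stand-ins for the prompt builder and LLM class (the template directory is a placeholder, not taken from the commit):

from pathlib import Path
from typing import Optional

# Hypothetical stand-in for PrototyperTemplateBuilder.system_instructions();
# the real builder renders instruction template files for a benchmark.
def system_instructions(template_dir: str, filenames: list[str]) -> list[str]:
  return [(Path(template_dir) / name).read_text() for name in filenames]

class FakeLLM:
  """Minimal stand-in for llm_toolkit.models.LLM with the new attribute."""
  system_instruction: Optional[list] = None

llm = FakeLLM()
llm.system_instruction = system_instructions(
    'prompts/agent',  # hypothetical template directory
    [
        'prototyper-system-instruction-objective.txt',
        'prototyper-system-instruction-protocols.txt',
    ])

# The protocols file is listed last, so system_instruction[-1] is the text
# replayed whenever the model's response matches neither the tool format
# nor the conclusion format.
protocol = llm.system_instruction[-1]
reteach_prompt = ('No valid instruction received, Please follow the system '
                  f'instructions:\n{protocol}')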
llm_toolkit/models.py (3 additions, 1 deletion)

@@ -59,6 +59,7 @@ class LLM:
   MAX_INPUT_TOKEN: int = sys.maxsize

   _max_attempts = 5  # Maximum number of attempts to get prediction response
+  system_instruction: Optional[list] = None

   def __init__(
       self,
@@ -564,7 +565,8 @@ class GeminiModel(VertexAIModel):
   ]

   def get_model(self) -> Any:
-    return GenerativeModel(self._vertex_ai_model)
+    return GenerativeModel(self._vertex_ai_model,
+                           system_instruction=self.system_instruction)

   def do_generate(self, model: Any, prompt: str, config: dict[str, Any]) -> Any:
     # Loosen inapplicable restrictions just in case.
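On the model side, GeminiModel.get_model now forwards the attribute to Vertex AI's GenerativeModel, whose constructor accepts system instructions as a string or a list of strings (and tolerates None, the class default). A small usage sketch outside the toolkit; the project, location, model name, and instruction strings below are illustrative placeholders, not taken from the commit:

import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project='my-project', location='us-central1')  # hypothetical

# system_instruction may be a single string or a list of strings; passing
# None leaves the model without any system instructions.
model = GenerativeModel(
    'gemini-1.5-pro',  # example model name
    system_instruction=[
        'You are a security engineer writing fuzz targets.',
        'Reply only in the tool-usage protocol described below.',
    ])
response = model.generate_content('Write a fuzz target for libxml2.')
print(response.text)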
