Support running multiple bash commands and compile in one query #736

Open: wants to merge 4 commits into main. Changes shown from all commits.
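
This PR changes the agents so that one LLM reply can carry several <bash> blocks (and, for the Prototyper, a fuzz target to compile) that are all handled in the same round, with the combined results appended to the running prompt instead of being rebuilt from a fresh template. A minimal sketch of the multi-tag parsing this relies on (the new _parse_tags helper) is shown below; the sample reply is made up, and only the tag convention and the regex come from the diff.

import re

# Made-up LLM reply carrying two <bash> blocks in a single message.
reply = '''
<bash>
grep -rn "LLVMFuzzerTestOneInput" /src/project | head
</bash>
<bash>
cat /src/project/build.sh
</bash>
'''

def parse_tags(response: str, tag: str) -> list[str]:
    # Same pattern as the new BaseAgent._parse_tags: every <tag>...</tag> body, stripped.
    return [m.strip() for m in re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)]

for command in parse_tags(reply, 'bash'):
    print('would execute:', command)
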
57 changes: 45 additions & 12 deletions agent/base_agent.py
@@ -10,7 +10,6 @@
import logger
import utils
from llm_toolkit.models import LLM
from llm_toolkit.prompt_builder import DefaultTemplateBuilder
from llm_toolkit.prompts import Prompt
from results import Result
from tool.base_tool import BaseTool
@@ -63,6 +62,11 @@ def _parse_tag(self, response: str, tag: str) -> str:
match = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
return match.group(1).strip() if match else ''

def _parse_tags(self, response: str, tag: str) -> list[str]:
"""Parses the XML-style tags from LLM response."""
matches = re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
return [content.strip() for content in matches]

def _filter_code(self, raw_code_block: str) -> str:
"""Filters out irrelevant lines from |raw_code_block|."""
# TODO(dongge): Move this function to a separate module.
@@ -74,27 +78,56 @@ def _filter_code(self, raw_code_block: str) -> str:
filtered_code_block = '\n'.join(filtered_lines)
return filtered_code_block

def _format_bash_execution_result(self, process: sp.CompletedProcess) -> str:
def _format_bash_execution_result(
self,
process: sp.CompletedProcess,
previous_prompt: Optional[Prompt] = None) -> str:
"""Formats a prompt based on bash execution result."""
stdout = self.llm.truncate_prompt(process.stdout)
# TODO(dongge) Share input limit evenly if both stdout and stderr overlong.
stderr = self.llm.truncate_prompt(process.stderr, stdout)
if previous_prompt:
previous_prompt_text = previous_prompt.get()
else:
previous_prompt_text = ''
stdout = self.llm.truncate_prompt(process.stdout,
previous_prompt_text).strip()
stderr = self.llm.truncate_prompt(process.stderr,
stdout + previous_prompt_text).strip()
return (f'<bash>\n{process.args}\n</bash>\n'
f'<return code>\n{process.returncode}\n</return code>\n'
f'<stdout>\n{stdout}\n</stdout>\n'
f'<stderr>\n{stderr}\n</stderr>\n')

def _container_handle_bash_command(self, command: str,
tool: BaseTool) -> Prompt:
def _container_handle_bash_command(self, response: str, tool: BaseTool,
prompt: Prompt) -> Prompt:
"""Handles the command from LLM with container |tool|."""
prompt_text = self._format_bash_execution_result(tool.execute(command))
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])

def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
prompt_text = ''
for command in self._parse_tags(response, 'bash'):
prompt_text += self._format_bash_execution_result(
tool.execute(command), previous_prompt=prompt) + '\n'
prompt.append(prompt_text)
return prompt

def _container_handle_invalid_tool_usage(self, tool: BaseTool, cur_round: int,
response: str,
prompt: Prompt) -> Prompt:
"""Formats a prompt to re-teach LLM how to use the |tool|."""
logger.warning('ROUND %02d Invalid response from LLM: %s',
cur_round,
response,
trial=self.trial)
prompt_text = (f'No valid instruction received. Please follow the '
f'interaction protocols:\n{tool.tutorial()}')
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
prompt.append(prompt_text)
return prompt

def _container_handle_bash_commands(self, response: str, tool: BaseTool,
prompt: Prompt) -> Prompt:
"""Handles the command from LLM with container |tool|."""
prompt_text = ''
for command in self._parse_tags(response, 'bash'):
prompt_text += self._format_bash_execution_result(
tool.execute(command), previous_prompt=prompt) + '\n'
prompt.append(prompt_text)
return prompt

def _sleep_random_duration(
self,
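
For context on the _format_bash_execution_result change above: stdout and stderr are now truncated against the text already in the prompt before the result is echoed back to the LLM. The sketch below reproduces the feedback format with a plain character cap standing in for llm.truncate_prompt; the 8,000-character budget is an assumption made for the example, not a value from the code.

import subprocess as sp

def format_bash_execution_result(process: sp.CompletedProcess,
                                 previous_prompt_text: str = '',
                                 char_budget: int = 8000) -> str:
    # Sketch of BaseAgent._format_bash_execution_result; a character cap stands
    # in for the token-aware llm.truncate_prompt calls in the real method.
    stdout = process.stdout[:max(0, char_budget - len(previous_prompt_text))].strip()
    stderr = process.stderr[:max(
        0, char_budget - len(previous_prompt_text) - len(stdout))].strip()
    return (f'<bash>\n{process.args}\n</bash>\n'
            f'<return code>\n{process.returncode}\n</return code>\n'
            f'<stdout>\n{stdout}\n</stdout>\n'
            f'<stderr>\n{stderr}\n</stderr>\n')

result = sp.run('echo hello; ls /nonexistent', shell=True,
                capture_output=True, text=True, check=False)
print(format_bash_execution_result(result))
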
41 changes: 24 additions & 17 deletions agent/prototyper.py
@@ -179,11 +179,13 @@ def _validate_fuzz_target_and_build_script_via_compile(
status=compile_succeed and binary_exists,
referenced=function_referenced)

def _container_handle_conclusion(
self, cur_round: int, response: str,
build_result: BuildResult) -> Optional[Prompt]:
def _container_handle_conclusion(self, cur_round: int, response: str,
build_result: BuildResult,
prompt: Prompt) -> Optional[Prompt]:
"""Runs a compilation tool to validate the new fuzz target and build script
from LLM."""
if not self._parse_tag(response, 'fuzz target'):
return prompt
logger.info('----- ROUND %02d Received conclusion -----',
cur_round,
trial=build_result.trial)
@@ -198,7 +200,8 @@ def _container_handle_conclusion(
return None

if not build_result.compiles:
compile_log = self.llm.truncate_prompt(build_result.compile_log)
compile_log = self.llm.truncate_prompt(build_result.compile_log,
extra_text=prompt.get()).strip()
logger.info('***** Failed to recompile in %02d rounds *****',
cur_round,
trial=build_result.trial)
@@ -243,25 +246,29 @@ def _container_handle_conclusion(
else:
prompt_text = ''

prompt = DefaultTemplateBuilder(self.llm, initial=prompt_text).build([])
prompt.append(prompt_text)
return prompt

def _container_tool_reaction(self, cur_round: int, response: str,
build_result: BuildResult) -> Optional[Prompt]:
"""Validates LLM conclusion or executes its command."""
# Prioritize Bash instructions.
if command := self._parse_tag(response, 'bash'):
return self._container_handle_bash_command(command, self.inspect_tool)
prompt = DefaultTemplateBuilder(self.llm, None).build([])
prompt = self._container_handle_bash_commands(response, self.inspect_tool,
prompt)

if self._parse_tag(response, 'conclusion'):
return self._container_handle_conclusion(cur_round, response,
build_result)
# Other responses are invalid.
logger.warning('ROUND %02d Invalid response from LLM: %s',
cur_round,
response,
trial=build_result.trial)
return self._container_handle_invalid_tool_usage(self.inspect_tool)
# Then build the fuzz target.
prompt = self._container_handle_conclusion(cur_round, response,
build_result, prompt)
if prompt is None:
# Succeeded.
return None

# Finally check invalid responses.
if not prompt.get():
prompt = self._container_handle_invalid_tool_usage(
self.inspect_tool, cur_round, response, prompt)

return prompt

def execute(self, result_history: list[Result]) -> BuildResult:
"""Executes the agent based on previous result."""
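
The Prototyper's _container_tool_reaction no longer returns after the first matching tag; every round now runs all <bash> commands, then tries the compile step, and only re-teaches the protocol if nothing was handled. The self-contained sketch below paraphrases that control flow; the stand-in handlers and the toy compile check are illustrative, not the project's code.

import re
from typing import Optional, Tuple

def _parse_tag(response: str, tag: str) -> str:
    match = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
    return match.group(1).strip() if match else ''

def _parse_tags(response: str, tag: str) -> list[str]:
    return [m.strip() for m in re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)]

def handle_bash_commands(response: str) -> str:
    # Stand-in for _container_handle_bash_commands: pretend to run each command.
    return ''.join(f'[ran: {cmd}]\n' for cmd in _parse_tags(response, 'bash'))

def handle_conclusion(response: str) -> Tuple[bool, str]:
    # Stand-in for _container_handle_conclusion: "compile" the fuzz target.
    target = _parse_tag(response, 'fuzz target')
    if not target:
        return False, ''  # nothing to build this round
    compiles = 'LLVMFuzzerTestOneInput' in target  # toy success criterion
    return compiles, '' if compiles else '[compile log would be appended here]\n'

def container_tool_reaction(cur_round: int, response: str) -> Optional[str]:
    # Condensed paraphrase of the new Prototyper._container_tool_reaction flow.
    prompt = handle_bash_commands(response)           # 1. run every <bash> block
    compiled, feedback = handle_conclusion(response)  # 2. then try to build
    if compiled:
        return None                                   #    success ends the session
    prompt += feedback                                #    failure appends the log
    if not prompt:
        # 3. neither bash nor a fuzz target: re-teach the interaction protocol
        prompt = (f'Round {cur_round:02d}: no valid instruction received. '
                  'Please follow the interaction protocols.')
    return prompt

print(container_tool_reaction(1, '<bash>\nls /src\n</bash>'))
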
13 changes: 10 additions & 3 deletions llm_toolkit/models.py
@@ -31,7 +31,8 @@
import openai
import tiktoken
import vertexai
from google.api_core.exceptions import (GoogleAPICallError, InvalidArgument,
from google.api_core.exceptions import (GoogleAPICallError,
InternalServerError, InvalidArgument,
ResourceExhausted, ServiceUnavailable,
TooManyRequests)
from vertexai import generative_models
@@ -649,6 +650,7 @@ def get_chat_client(self, model: GenerativeModel) -> Any:
InvalidArgument,
ValueError, # TODO(dongge): Handle RECITATION specifically.
IndexError, # A known error from vertexai.
InternalServerError,
],
other_exceptions={
ResourceExhausted: 100,
@@ -680,10 +682,15 @@ def truncate_prompt(self,
max_raw_prompt_token_size = (self.MAX_INPUT_TOKEN - extra_text_token_count -
10000)

while token_count > max_raw_prompt_token_size:
while token_count > max_raw_prompt_token_size // 4:
estimate_truncate_size = int(
(1 - max_raw_prompt_token_size / token_count) * len(raw_prompt_text))
raw_prompt_text = raw_prompt_text[estimate_truncate_size + 1:]

logger.warning('estimate_truncate_size: %s', estimate_truncate_size)
raw_prompt_init = raw_prompt_text[:100] + (
'\n...(truncated due to exceeding input token limit)...\n')
raw_prompt_text = raw_prompt_init + raw_prompt_text[
100 + estimate_truncate_size + 1:]

token_count = self.estimate_token_num(raw_prompt_text)
logger.warning('Truncated raw prompt from %d to %d tokens:',
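
The truncate_prompt change in models.py keeps the beginning of an over-long text, splices in an explicit notice, and drops an estimated slice of what follows until the text sits well under the budget (a quarter of the allowance, per the new // 4 condition), instead of silently chopping the head off. A standalone sketch is below; the 4-characters-per-token estimate and the 8,000-token budget are assumptions for the example, not values from models.py.

def truncate_middle(raw_text: str,
                    max_tokens: int = 8000,
                    chars_per_token: int = 4) -> str:
    # Sketch of the head-preserving truncation added to truncate_prompt: keep the
    # first 100 characters, insert a notice, and drop an estimated slice of the rest.
    notice = '\n...(truncated due to exceeding input token limit)...\n'

    def estimate_tokens(text: str) -> int:
        return len(text) // chars_per_token  # crude stand-in for the real tokenizer

    while estimate_tokens(raw_text) > max_tokens // 4:
        overshoot = estimate_tokens(raw_text) - max_tokens // 4
        drop = overshoot * chars_per_token + len(notice)
        raw_text = raw_text[:100] + notice + raw_text[100 + drop:]
    return raw_text

# A 200k-character log shrinks to roughly a quarter of the assumed character budget.
print(len(truncate_middle('x' * 200_000)))
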
2 changes: 1 addition & 1 deletion utils.py
@@ -37,7 +37,7 @@ def deserialize_from_dill(dill_path: Any) -> Any:
def _default_retry_delay_fn(e: Exception, n: int):
"""Delays retry by a random seconds between 0 to 1 minute."""
del e, n
return random.uniform(0, 60)
return random.uniform(60, 120)


def retryable(exceptions=None,
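
Finally, utils.py widens the default retry backoff from 0-60 seconds to 60-120 seconds. The snippet below illustrates how such a delay function is typically consumed by a retry wrapper; the call_with_retries loop is illustrative and not the project's actual retryable decorator, which is only partially visible in this diff.

import random
import time

def default_retry_delay_fn(e: Exception, n: int) -> float:
    # Returns a random delay between 1 and 2 minutes, ignoring the exception
    # and the attempt number, mirroring the updated utils default.
    del e, n
    return random.uniform(60, 120)

def call_with_retries(func, attempts: int = 3, delay_fn=default_retry_delay_fn):
    # Illustrative retry loop; the real project presumably wires a delay
    # function like this through its retryable() decorator.
    for n in range(1, attempts + 1):
        try:
            return func()
        except Exception as e:  # the real decorator limits this to configured exceptions
            if n == attempts:
                raise
            time.sleep(delay_fn(e, n))
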