Tooled Prototyper #731

Open · wants to merge 9 commits into main
38 changes: 22 additions & 16 deletions agent/base_agent.py
@@ -1,5 +1,6 @@
"""The abstract base class for LLM agents in stages."""
import argparse
import codecs
import random
import re
import subprocess as sp
@@ -10,7 +11,6 @@
import logger
import utils
from llm_toolkit.models import LLM
from llm_toolkit.prompt_builder import DefaultTemplateBuilder
from llm_toolkit.prompts import Prompt
from results import Result
from tool.base_tool import BaseTool
@@ -43,7 +43,7 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]:
return None

def chat_llm(self, cur_round: int, client: Any, prompt: Prompt,
trial: int) -> str:
trial: int) -> Any:
"""Chat with LLM."""
logger.info('<CHAT PROMPT:ROUND %02d>%s</CHAT PROMPT:ROUND %02d>',
cur_round,
@@ -67,34 +67,40 @@ def _filter_code(self, raw_code_block: str) -> str:
"""Filters out irrelevant lines from |raw_code_block|."""
# TODO(dongge): Move this function to a separate module.
# Remove markdown-style code block symbols.

raw_code_block = codecs.decode(raw_code_block, 'unicode_escape').strip()

filtered_lines = [
line for line in raw_code_block.splitlines()
if not line.strip().startswith('```')
]
filtered_code_block = '\n'.join(filtered_lines)
return filtered_code_block

def _format_bash_execution_result(self, process: sp.CompletedProcess) -> str:
def _format_bash_execution_result(self, process: sp.CompletedProcess) -> dict:
"""Formats a prompt based on bash execution result."""
stdout = self.llm.truncate_prompt(process.stdout)
# TODO(dongge) Share input limit evenly if both stdout and stderr overlong.
stderr = self.llm.truncate_prompt(process.stderr, stdout)
return (f'<bash>\n{process.args}\n</bash>\n'
f'<return code>\n{process.returncode}\n</return code>\n'
f'<stdout>\n{stdout}\n</stdout>\n'
f'<stderr>\n{stderr}\n</stderr>\n')

def _container_handle_bash_command(self, command: str,
tool: BaseTool) -> Prompt:
return {
'command': process.args,
'returncode': process.returncode,
'stdout': stdout,
'stderr': stderr,
}

def _container_handle_bash_command(self, args: dict, tool: BaseTool) -> dict:
"""Handles the command from LLM with container |tool|."""
prompt_text = self._format_bash_execution_result(tool.execute(command))
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
return args | self._format_bash_execution_result(
tool.execute(self._filter_code(args.get('command', ''))))

def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
def _container_handle_invalid_tool_usage(self, args: dict,
tool: BaseTool) -> dict:
"""Formats a prompt to re-teach LLM how to use the |tool|."""
prompt_text = (f'No valid instruction received, Please follow the '
f'interaction protocols:\n{tool.tutorial()}')
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
return args | {
'error': ('Malformed function call: function name is not in '
f'{[decl.name for decl in tool.declarations()]}')
}

def _sleep_random_duration(
self,
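For context on the `agent/base_agent.py` changes above: bash results are no longer rendered into an XML-style prompt string; `_format_bash_execution_result` now returns a plain dict, and `_container_handle_bash_command` merges it into the tool-call arguments with the `|` operator before the result is wrapped into a function response. Below is a minimal, self-contained sketch of that flow; the helper names and the direct `subprocess.run` call are illustrative stand-ins for `BaseTool.execute`, and the real code also truncates stdout/stderr through the LLM wrapper.

```python
import subprocess as sp


def format_bash_execution_result(process: sp.CompletedProcess) -> dict:
    # Structured counterpart of the old XML-style prompt text; the real
    # method also truncates stdout/stderr via the LLM wrapper.
    return {
        'command': process.args,
        'returncode': process.returncode,
        'stdout': process.stdout,
        'stderr': process.stderr,
    }


def handle_bash_command(args: dict) -> dict:
    # Run the command requested by the LLM and merge the structured result
    # back into the original call arguments, as the new handler does with `|`.
    process = sp.run(args.get('command', ''),
                     shell=True,
                     capture_output=True,
                     text=True,
                     check=False)
    return args | format_bash_execution_result(process)


if __name__ == '__main__':
    # The merged dict becomes the `content` of a function-response Part.
    print(handle_bash_command({'command': 'echo hello'}))
```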
186 changes: 114 additions & 72 deletions agent/prototyper.py
@@ -6,6 +6,10 @@
from datetime import timedelta
from typing import Optional

from vertexai.preview.generative_models import (ChatSession, FunctionCall,
GenerationResponse, Part, Tool,
ToolConfig)

import logger
from agent.base_agent import BaseAgent
from data_prep.project_context.context_introspector import ContextRetriever
@@ -32,20 +36,31 @@ def _initial_prompt(self, results: list[Result]) -> Prompt:
benchmark=benchmark,
)
prompt = prompt_builder.build(example_pair=[],
project_context_content=context_info,
tool_guides=self.inspect_tool.tutorial())
project_context_content=context_info)
# prompt = prompt_builder.build(example_pair=EXAMPLE_FUZZ_TARGETS.get(
# benchmark.language, []),
# tool_guides=self.inspect_tool.tutorial())
return prompt

def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
def _initialize_chat_session(self) -> ChatSession:
"""Initializes the LLM chat session with |tools|"""
self.llm.tools = [
Tool(function_declarations=self.inspect_tool.declarations())
]
self.llm.tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY))
model = self.llm.get_model()
return self.llm.get_chat_client(model=model)

def _update_fuzz_target_and_build_script(self, cur_round: int, args: dict,
build_result: BuildResult) -> None:
"""Updates fuzz target and build script in build_result with LLM response.
"""
fuzz_target_source = self._filter_code(
self._parse_tag(response, 'fuzz target'))
fuzz_target_source = self._filter_code(args.get('fuzz_target', ''))
build_result.fuzz_target_source = fuzz_target_source

args['fuzz_target'] = fuzz_target_source
if fuzz_target_source:
logger.debug('ROUND %02d Parsed fuzz target from LLM: %s',
cur_round,
@@ -54,14 +69,11 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
else:
logger.error('ROUND %02d No fuzz target source code in conclusion: %s',
cur_round,
response,
args,
trial=build_result.trial)

build_script_source = self._filter_code(
self._parse_tag(response, 'build script'))
# Sometimes LLM adds chronos, which makes no sense for new build scripts.
build_result.build_script_source = build_script_source.replace(
'source /src/chronos.sh', '')
build_script_source = self._filter_code(args.get('build_script', ''))
args['build_script'] = build_script_source
if build_script_source:
logger.debug('ROUND %02d Parsed build script from LLM: %s',
cur_round,
@@ -70,17 +82,19 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
else:
logger.debug('ROUND %02d No build script in conclusion: %s',
cur_round,
response,
args,
trial=build_result.trial)

def _update_build_result(self, build_result: BuildResult,
compile_process: sp.CompletedProcess, status: bool,
referenced: bool) -> None:
"""Updates the build result with the latest info."""
build_result.compiles = status
build_result.compile_error = compile_process.stderr
build_result.compile_log = self._format_bash_execution_result(
compile_process)

compile_result = self._format_bash_execution_result(compile_process)
build_result.compile_stdout = compile_result.get('stdout', '')
build_result.compile_stderr = compile_result.get('stderr', '')
build_result.compile_log = str(compile_result)
build_result.is_function_referenced = referenced

def _validate_fuzz_target_and_build_script(self, cur_round: int,
@@ -179,16 +193,15 @@ def _validate_fuzz_target_and_build_script_via_compile(
status=compile_succeed and binary_exists,
referenced=function_referenced)

def _container_handle_conclusion(
self, cur_round: int, response: str,
build_result: BuildResult) -> Optional[Prompt]:
def _container_handle_conclusion(self, cur_round: int, args: dict,
build_result: BuildResult) -> Optional[dict]:
"""Runs a compilation tool to validate the new fuzz target and build script
from LLM."""
logger.info('----- ROUND %02d Received conclusion -----',
cur_round,
trial=build_result.trial)

self._update_fuzz_target_and_build_script(cur_round, response, build_result)
self._update_fuzz_target_and_build_script(cur_round, args, build_result)

self._validate_fuzz_target_and_build_script(cur_round, build_result)
if build_result.success:
@@ -198,70 +211,99 @@ def _container_handle_conclusion(
return None

if not build_result.compiles:
compile_log = self.llm.truncate_prompt(build_result.compile_log)
logger.info('***** Failed to recompile in %02d rounds *****',
cur_round,
trial=build_result.trial)
prompt_text = (
'Failed to build fuzz target. Here is the fuzz target, build script, '
'compliation command, and other compilation runtime output. Analyze '
'the error messages, the fuzz target, and the build script carefully '
'to identify the root cause. Avoid making random changes to the fuzz '
'target or build script without a clear understanding of the error. '
'If necessary, #include necessary headers and #define required macros'
'or constants in the fuzz target, or adjust compiler flags to link '
'required libraries in the build script. After collecting information'
', analyzing and understanding the error root cause, YOU MUST take at'
' least one step to validate your theory with source code evidence. '
'Only if your theory is verified, respond the revised fuzz target and'
'build script in FULL.\n'
'Always try to learn from the source code about how to fix errors, '
'for example, search for the key words (e.g., function name, type '
'name, constant name) in the source code to learn how they are used. '
'Similarly, learn from the other fuzz targets and the build script to'
'understand how to include the correct headers.\n'
'Focus on writing a minimum buildable fuzz target that calls the '
'target function. We can increase its complexity later, but first try'
'to make it compile successfully.'
'If an error happens repeatedly and cannot be fixed, try to '
'mitigate it. For example, replace or remove the line.'
f'<fuzz target>\n{build_result.fuzz_target_source}\n</fuzz target>\n'
f'<build script>\n{build_result.build_script_source}\n</build script>'
f'\n<compilation log>\n{compile_log}\n</compilation log>\n')
elif not build_result.is_function_referenced:
content = {
'compilation_result':
str(build_result.compiles),
'stdout':
build_result.compile_stdout,
'stderr':
build_result.compile_stderr,
'fix_guide': (
'Failed to build fuzz target. Analyze the stdout, stderr, the fuzz '
'target, and the build script carefully '
'to identify the root cause. Avoid making random changes to the fuzz '
'target or build script without a clear understanding of the error. '
'If necessary, #include necessary headers and #define required macros '
'or constants in the fuzz target, or adjust compiler flags to link '
'required libraries in the build script. After collecting information'
', analyzing and understanding the error root cause, YOU MUST take at'
' least one step to validate your theory with source code evidence. '
'Only if your theory is verified, respond with the revised fuzz target and '
'build script in FULL.\n'
'Always try to learn from the source code about how to fix errors, '
'for example, search for the key words (e.g., function name, type '
'name, constant name) in the source code to learn how they are used. '
'Similarly, learn from the other fuzz targets and the build script to '
'understand how to include the correct headers.\n'
'Focus on writing a minimum buildable fuzz target that calls the '
'target function. We can increase its complexity later, but first try '
'to make it compile successfully. '
'If an error happens repeatedly and cannot be fixed, try to '
'mitigate it. For example, replace or remove the line.')
}
elif not build_result.is_function_referenced and not args.get(
'referenced', False):
logger.info(
'***** Fuzz target does not reference function-under-test in %02d '
'rounds *****',
cur_round,
trial=build_result.trial)
prompt_text = (
'The fuzz target builds successfully, but the target function '
f'`{build_result.benchmark.function_signature}` was not used by '
'`LLVMFuzzerTestOneInput` in fuzz target. YOU MUST CALL FUNCTION '
f'`{build_result.benchmark.function_signature}` INSIDE FUNCTION '
'`LLVMFuzzerTestOneInput`.')
content = {
'compilation_result':
str(build_result.compiles),
'stdout':
build_result.compile_stdout,
'stderr':
build_result.compile_stderr,
'fix_guide': (
'The fuzz target builds successfully, but the target function '
f'`{build_result.benchmark.function_signature}` was not used by '
'`LLVMFuzzerTestOneInput` in fuzz target. YOU MUST CALL FUNCTION '
f'`{build_result.benchmark.function_signature}` INSIDE FUNCTION '
'`LLVMFuzzerTestOneInput`.')
}
else:
prompt_text = ''
content = {
'compilation_result': str(build_result.compiles),
'stdout': build_result.compile_stdout,
'stderr': build_result.compile_stderr,
'fix_guide': 'Unknown compilation failure'
}

prompt = DefaultTemplateBuilder(self.llm, initial=prompt_text).build([])
return prompt
return args | content

def _container_tool_reaction(self, cur_round: int, response: str,
build_result: BuildResult) -> Optional[Prompt]:
"""Validates LLM conclusion or executes its command."""
# Prioritize Bash instructions.
if command := self._parse_tag(response, 'bash'):
return self._container_handle_bash_command(command, self.inspect_tool)
def _call_function(self, cur_round: int, call: FunctionCall,
build_result: BuildResult) -> Optional[Part]:
"""Calls tool functions based on LLM response."""
if call.name == 'bash':
content = self._container_handle_bash_command(call.args,
self.inspect_tool)
elif call.name == 'compile':
content = self._container_handle_conclusion(cur_round, call.args,
build_result)
if content is None:
return None

if self._parse_tag(response, 'conclusion'):
return self._container_handle_conclusion(cur_round, response,
build_result)
# Other responses are invalid.
logger.warning('ROUND %02d Invalid response from LLM: %s',
cur_round,
response,
trial=build_result.trial)
return self._container_handle_invalid_tool_usage(self.inspect_tool)
else:
content = self._container_handle_invalid_tool_usage(
call.args, self.inspect_tool)
return Part.from_function_response(name=call.name,
response={'content': content})

def _container_tool_reaction(self, cur_round: int,
response: GenerationResponse,
build_result: BuildResult) -> Optional[Prompt]:
"""Executes bash command or validates fuzz targets."""
results = []
for call in response.candidates[0].function_calls:
result = self._call_function(cur_round, call, build_result)
if result is None:
return None
results.append(result)
return DefaultTemplateBuilder(self.llm, initial=results).build([])

def execute(self, result_history: list[Result]) -> BuildResult:
"""Executes the agent based on previous result."""
@@ -278,7 +320,7 @@ def execute(self, result_history: list[Result]) -> BuildResult:
chat_history={self.name: ''})
prompt = self._initial_prompt(result_history)
try:
client = self.llm.get_chat_client(model=self.llm.get_model())
client = self._initialize_chat_session()
while prompt and cur_round < MAX_ROUND:
response = self.chat_llm(cur_round,
client=client,
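For context on the `agent/prototyper.py` changes above: `_initialize_chat_session` registers the inspection tool's declarations as a Vertex AI `Tool` and forces function calling via `ToolConfig.FunctionCallingConfig.Mode.ANY`, while `_call_function` answers each `FunctionCall` with `Part.from_function_response`. The sketch below shows that round trip outside the project's `LLM` wrapper; the `bash` declaration schema, the model name, and the canned reply are illustrative assumptions, and passing `tool_config` to the `GenerativeModel` constructor assumes a recent `vertexai` SDK.

```python
# Assumes vertexai.init(project=..., location=...) has already been called.
from vertexai.preview.generative_models import (FunctionDeclaration,
                                                GenerativeModel, Part, Tool,
                                                ToolConfig)

# Hypothetical declaration standing in for inspect_tool.declarations().
bash_decl = FunctionDeclaration(
    name='bash',
    description='Run a bash command inside the project container.',
    parameters={
        'type': 'object',
        'properties': {
            'command': {
                'type': 'string',
                'description': 'The bash command to execute.',
            },
        },
        'required': ['command'],
    })

# Force the model to respond with a function call on every turn.
tool_config = ToolConfig(
    function_calling_config=ToolConfig.FunctionCallingConfig(
        mode=ToolConfig.FunctionCallingConfig.Mode.ANY))

# Model name is a placeholder; the PR routes this through the LLM wrapper.
model = GenerativeModel('gemini-1.5-pro',
                        tools=[Tool(function_declarations=[bash_decl])],
                        tool_config=tool_config)
chat = model.start_chat()

response = chat.send_message('Inspect how the project builds its fuzz targets.')
for call in response.candidates[0].function_calls:
    # Execute the requested command (elided here) and feed the structured
    # result back to the model as a function response.
    reply = Part.from_function_response(name=call.name,
                                        response={'content': {'stdout': '...'}})
    response = chat.send_message(reply)
```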
2 changes: 1 addition & 1 deletion ci/request_pr_exp.py
@@ -41,7 +41,7 @@
'large-pr-exp.yaml')
BENCHMARK_SET = 'comparison'
LLM_NAME = 'vertex_ai_gemini-1-5'
LLM_CHAT_NAME = 'vertex_ai_gemini-1-5-chat'
LLM_CHAT_NAME = 'vertex_ai_gemini-1-5-chat-tool'
EXP_DELAY = 0
FUZZING_TIMEOUT = 300
REQUEST_CPU = 6