Skip to content

Commit

Permalink
Fix multiple threads sharing one logger (#720)
Browse files Browse the repository at this point in the history
In the original code
([_fuzzing_pipelines](https://github.com/google/oss-fuzz-gen/blob/c9f87defe821ab66f123b7730b59eb2a93c6608b/run_one_experiment.py#L336),
[_fuzzing_pipeline](https://github.com/google/oss-fuzz-gen/blob/c9f87defe821ab66f123b7730b59eb2a93c6608b/run_one_experiment.py#L321),
[get_trial_logger](https://github.com/google/oss-fuzz-gen/blob/c9f87defe821ab66f123b7730b59eb2a93c6608b/logger.py#L126)),
multiple threads in the thread pool end up sharing a single logger. This
change uses threading.local() to create a separate logger for each thread.
  • Loading branch information
maoyixie authored Nov 24, 2024
1 parent cd048ba commit a1e2fc0
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 78 deletions.
22 changes: 17 additions & 5 deletions agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]:
return tool
return None

def chat_llm(self, cur_round: int, client: Any, prompt: Prompt) -> str:
def chat_llm(self, cur_round: int, client: Any, prompt: Prompt,
trial: int) -> str:
"""Chat with LLM."""
logger.info('<CHAT PROMPT:ROUND %02d>%s</CHAT PROMPT:ROUND %02d>',
cur_round, prompt.get(), cur_round)
cur_round,
prompt.get(),
cur_round,
trial=trial)
response = self.llm.chat_llm(client=client, prompt=prompt)
logger.info('<CHAT RESPONSE:ROUND %02d>%s</CHAT RESPONSE:ROUND %02d>',
cur_round, response, cur_round)
cur_round,
response,
cur_round,
trial=trial)
return response

def _parse_tag(self, response: str, tag: str) -> str:
Expand Down Expand Up @@ -89,11 +96,16 @@ def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
f'interaction protocols:\n{tool.tutorial()}')
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])

def _sleep_random_duration(self, min_sec: int = 1, max_sec: int = 60) -> None:
def _sleep_random_duration(
self,
trial: int,
min_sec: int = 1,
max_sec: int = 60,
) -> None:
"""Sleeps for a random duration between min_sec and max_sec. Agents uses
this to avoid exceeding quota limit (e.g., LLM query frequency)."""
duration = random.randint(min_sec, max_sec)
logger.debug('Sleeping for %d before the next query', duration)
logger.debug('Sleeping for %d before the next query', duration, trial=trial)
time.sleep(duration)

@classmethod
Expand Down
94 changes: 64 additions & 30 deletions agent/prototyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from agent.base_agent import BaseAgent
from data_prep.project_context.context_introspector import ContextRetriever
from experiment.benchmark import Benchmark
from llm_toolkit.prompt_builder import EXAMPLES as EXAMPLE_FUZZ_TARGETS
from llm_toolkit.prompt_builder import (DefaultTemplateBuilder,
PrototyperTemplateBuilder)
from llm_toolkit.prompts import Prompt
Expand Down Expand Up @@ -48,23 +47,31 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
self._parse_tag(response, 'fuzz target'))
build_result.fuzz_target_source = fuzz_target_source
if fuzz_target_source:
logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', cur_round,
fuzz_target_source)
logger.debug('ROUND %02d Parsed fuzz target from LLM: %s',
cur_round,
fuzz_target_source,
trial=build_result.trial)
else:
logger.error('ROUND %02d No fuzz target source code in conclusion: %s',
cur_round, response)
cur_round,
response,
trial=build_result.trial)

build_script_source = self._filter_code(
self._parse_tag(response, 'build script'))
# Sometimes LLM adds chronos, which makes no sense for new build scripts.
build_result.build_script_source = build_script_source.replace(
'source /src/chronos.sh', '')
if build_script_source:
logger.debug('ROUND %02d Parsed build script from LLM: %s', cur_round,
build_script_source)
logger.debug('ROUND %02d Parsed build script from LLM: %s',
cur_round,
build_script_source,
trial=build_result.trial)
else:
logger.debug('ROUND %02d No build script in conclusion: %s', cur_round,
response)
logger.debug('ROUND %02d No build script in conclusion: %s',
cur_round,
response,
trial=build_result.trial)

def _update_build_result(self, build_result: BuildResult,
compile_process: sp.CompletedProcess, status: bool,
Expand All @@ -84,20 +91,22 @@ def _validate_fuzz_target_and_build_script(self, cur_round: int,
# 2. Recompile with the modified build script, if any.
build_script_source = build_result.build_script_source

logger.info('First compile fuzz target without modifying build script.')
logger.info('First compile fuzz target without modifying build script.',
trial=build_result.trial)
build_result.build_script_source = ''
self._validate_fuzz_target_and_build_script_via_compile(
cur_round, build_result)

if not build_result.success and build_script_source:
logger.info('Then compile fuzz target with modified build script.')
logger.info('Then compile fuzz target with modified build script.',
trial=build_result.trial)
build_result.build_script_source = build_script_source
self._validate_fuzz_target_and_build_script_via_compile(
cur_round, build_result)

def _validate_fuzz_target_references_function(
self, compilation_tool: ProjectContainerTool, benchmark: Benchmark,
cur_round: int) -> bool:
cur_round: int, trial: int) -> bool:
"""Validates if the LLM generated fuzz target assembly code references
function-under-test."""
disassemble_result = compilation_tool.execute(
Expand All @@ -106,10 +115,13 @@ def _validate_fuzz_target_references_function(
function_referenced = (disassemble_result.returncode == 0 and
benchmark.function_name in disassemble_result.stdout)
logger.debug('ROUND %02d Final fuzz target function referenced: %s',
cur_round, function_referenced)
cur_round,
function_referenced,
trial=trial)
if not function_referenced:
logger.debug('ROUND %02d Final fuzz target function not referenced',
cur_round)
cur_round,
trial=trial)
return function_referenced

def _validate_fuzz_target_and_build_script_via_compile(
Expand All @@ -133,25 +145,33 @@ def _validate_fuzz_target_and_build_script_via_compile(
file_content=build_result.build_script_source))

# Recompile.
logger.info('===== ROUND %02d Recompile =====', cur_round)
logger.info('===== ROUND %02d Recompile =====',
cur_round,
trial=build_result.trial)
start_time = time.time()
compile_process = compilation_tool.compile()
end_time = time.time()
logger.debug('ROUND %02d compilation time: %s', cur_round,
timedelta(seconds=end_time - start_time))
logger.debug('ROUND %02d compilation time: %s',
cur_round,
timedelta(seconds=end_time - start_time),
trial=build_result.trial)
compile_succeed = compile_process.returncode == 0
logger.debug('ROUND %02d Fuzz target compiles: %s', cur_round,
compile_succeed)
logger.debug('ROUND %02d Fuzz target compiles: %s',
cur_round,
compile_succeed,
trial=build_result.trial)

# Double-check binary.
ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}')
binary_exists = ls_result.returncode == 0
logger.debug('ROUND %02d Final fuzz target binary exists: %s', cur_round,
binary_exists)
logger.debug('ROUND %02d Final fuzz target binary exists: %s',
cur_round,
binary_exists,
trial=build_result.trial)

# Validate if function-under-test is referenced by the fuzz target.
function_referenced = self._validate_fuzz_target_references_function(
compilation_tool, benchmark, cur_round)
compilation_tool, benchmark, cur_round, build_result.trial)

compilation_tool.terminate()
self._update_build_result(build_result,
Expand All @@ -164,18 +184,24 @@ def _container_handle_conclusion(
build_result: BuildResult) -> Optional[Prompt]:
"""Runs a compilation tool to validate the new fuzz target and build script
from LLM."""
logger.info('----- ROUND %02d Received conclusion -----', cur_round)
logger.info('----- ROUND %02d Received conclusion -----',
cur_round,
trial=build_result.trial)

self._update_fuzz_target_and_build_script(cur_round, response, build_result)

self._validate_fuzz_target_and_build_script(cur_round, build_result)
if build_result.success:
logger.info('***** Prototyper succeded in %02d rounds *****', cur_round)
logger.info('***** Prototyper succeded in %02d rounds *****',
cur_round,
trial=build_result.trial)
return None

if not build_result.compiles:
compile_log = self.llm.truncate_prompt(build_result.compile_log)
logger.info('***** Failed to recompile in %02d rounds *****', cur_round)
logger.info('***** Failed to recompile in %02d rounds *****',
cur_round,
trial=build_result.trial)
prompt_text = (
'Failed to build fuzz target. Here is the fuzz target, build script, '
'compliation command, and other compilation runtime output. Analyze '
Expand Down Expand Up @@ -205,7 +231,9 @@ def _container_handle_conclusion(
elif not build_result.is_function_referenced:
logger.info(
'***** Fuzz target does not reference function-under-test in %02d '
'rounds *****', cur_round)
'rounds *****',
cur_round,
trial=build_result.trial)
prompt_text = (
'The fuzz target builds successfully, but the target function '
f'`{build_result.benchmark.function_signature}` was not used by '
Expand All @@ -229,14 +257,16 @@ def _container_tool_reaction(self, cur_round: int, response: str,
return self._container_handle_conclusion(cur_round, response,
build_result)
# Other responses are invalid.
logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round,
response)
logger.warning('ROUND %02d Invalid response from LLM: %s',
cur_round,
response,
trial=build_result.trial)
return self._container_handle_invalid_tool_usage(self.inspect_tool)

def execute(self, result_history: list[Result]) -> BuildResult:
"""Executes the agent based on previous result."""
logger.info('Executing Prototyper')
last_result = result_history[-1]
logger.info('Executing Prototyper', trial=last_result.trial)
benchmark = last_result.benchmark
self.inspect_tool = ProjectContainerTool(benchmark, name='inspect')
self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null')
Expand All @@ -250,13 +280,17 @@ def execute(self, result_history: list[Result]) -> BuildResult:
try:
client = self.llm.get_chat_client(model=self.llm.get_model())
while prompt and cur_round < MAX_ROUND:
response = self.chat_llm(cur_round, client=client, prompt=prompt)
response = self.chat_llm(cur_round,
client=client,
prompt=prompt,
trial=last_result.trial)
prompt = self._container_tool_reaction(cur_round, response,
build_result)
cur_round += 1
finally:
# Cleanup: stop and remove the container
logger.debug('Stopping and removing the inspect container %s',
self.inspect_tool.container_id)
self.inspect_tool.container_id,
trial=last_result.trial)
self.inspect_tool.terminate()
return build_result
71 changes: 34 additions & 37 deletions logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

FINAL_RESULT_JSON = 'result.json'

_trial_logger = None


class CustomLoggerAdapter(logging.LoggerAdapter):
"""A note-taker to log and record experiment status, key info, and final
Expand Down Expand Up @@ -61,76 +59,76 @@ def write_chat_history(self, result: Result) -> None:

def debug(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().debug(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).debug(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def info(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().info(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).info(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def warning(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().warning(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).warning(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def error(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().error(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).error(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def get_trial_logger(name: str = __name__,
trial: int = 0,
level=logging.DEBUG) -> CustomLoggerAdapter:
"""Sets up or retrieves the singleton instance of CustomLoggerAdapter."""
global _trial_logger
if _trial_logger:
return _trial_logger

"""Sets up or retrieves a thread-local CustomLoggerAdapter for each thread."""
logger = logging.getLogger(name)
if not logger.handlers:
formatter = logging.Formatter(
Expand All @@ -143,5 +141,4 @@ def get_trial_logger(name: str = __name__,
logger.setLevel(level)
logger.propagate = False

_trial_logger = CustomLoggerAdapter(logger, {'trial': trial})
return _trial_logger
return CustomLoggerAdapter(logger, {'trial': trial})
Loading

0 comments on commit a1e2fc0

Please sign in to comment.