Review fixes
guy-ps committed Jul 15, 2024
1 parent 177b80f commit db760ba
Showing 5 changed files with 9 additions and 12 deletions.
4 changes: 2 additions & 2 deletions ps_fuzz/app_config.py
@@ -194,10 +194,10 @@ def update_from_args(self, args):
setattr(self, key, value)
except AttributeError:
logger.warning(f"Attempt to set an undefined configuration property '{key}'")
- sys.exit(1)
+ raise
except Exception as e:
logger.error(f"Error setting {key}: {e}")
- sys.exit(1)
+ raise
self.save()

def parse_cmdline_args():
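Note on the hunk above: replacing sys.exit(1) with raise lets the original exception propagate out of the configuration helper instead of terminating the process from inside it, so the command-line entry point can decide how to report the failure and which exit code to use. A minimal sketch of the resulting control flow; the AppConfig and main shown here are simplified illustrations, not the repository's actual classes:

    import argparse
    import logging
    import sys

    logger = logging.getLogger(__name__)

    class AppConfig:
        def update_from_args(self, args):
            # Mirrors the patched behaviour: log, then re-raise rather than exiting here.
            for key, value in vars(args).items():
                try:
                    setattr(self, key, value)
                except Exception as e:
                    logger.error(f"Error setting {key}: {e}")
                    raise

    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--num-threads", type=int, default=4)
        args = parser.parse_args()
        config = AppConfig()
        try:
            config.update_from_args(args)
        except Exception as e:
            # Exiting (and choosing the exit code) is now the caller's decision.
            print(f"Configuration error: {e}", file=sys.stderr)
            sys.exit(1)

    if __name__ == "__main__":
        main()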
2 changes: 1 addition & 1 deletion ps_fuzz/attack_registry.py
@@ -15,7 +15,7 @@ def register_test(cls):
logger.debug(f"Registering attack test class: {cls.__name__}")
test_classes.append(cls)

- def instantiate_tests(client_config: ClientConfig, attack_config:AttackConfig,custom_tests:List=None,custom_benchmark:bool=False) -> List[TestBase]:
+ def instantiate_tests(client_config: ClientConfig, attack_config:AttackConfig, custom_tests:List=None, custom_benchmark:bool=False) -> List[TestBase]:
tests = []
for cls in test_classes:
# TODO: remove ...
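The only change in this hunk is cosmetic (spaces after commas), but the surrounding code shows the registry pattern the fuzzer relies on: register_test records each decorated attack class in a module-level list, and instantiate_tests later builds one object per recorded class. A self-contained, simplified sketch of that pattern, independent of the package's actual classes (EchoTest and the name-filter argument are made up for illustration):

    from typing import List, Optional, Type

    test_classes: List[Type] = []

    def register_test(cls):
        # Decorator: remember the class so it can be instantiated later.
        test_classes.append(cls)
        return cls

    @register_test
    class EchoTest:
        test_name = "echo_test"

    def instantiate_tests(only: Optional[List[str]] = None) -> list:
        # One instance per registered class, optionally filtered by name.
        return [cls() for cls in test_classes
                if only is None or cls.test_name in only]

    print([t.test_name for t in instantiate_tests()])   # prints ['echo_test']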
2 changes: 1 addition & 1 deletion ps_fuzz/attacks/custom_benchmark.py
@@ -54,9 +54,9 @@ def run(self) -> Generator[StatusUpdate, None, None]:
logger.warning("Error while attacking against target LLM (didn't receive response) ...")
self.status.report_error(attack_prompt)

- # The idea is to check that response contains something similar to expected response, but without refusal keywords
# This is a simple way to discriminate refusal from compliance without requiring a full NLP model.
# verified = not contains_refusal_keywords(response)
+ # The idea is to check that response contains something similar to expected response, but without refusal keywords
correct_response = evaluate_response(self.attack_config.attack_client,response,row['response']) and not contains_refusal_keywords(response)
logger.info(f"Test '{self.test_name}': attack prompt #{attack_prompt_index}: {attack_prompt}")
logger.info(f"Test '{self.test_name}': attack response #{attack_prompt_index} break_success={correct_response}): {response}")
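The verdict on the changed line combines two signals: evaluate_response asks the attack-side model whether the reply matches the expected answer from the benchmark row, and contains_refusal_keywords is a cheap lexical check that the target did not simply refuse. The package's own keyword check is not shown in this diff; the sketch below is one plausible shape for it, with an invented keyword list, purely for illustration:

    REFUSAL_KEYWORDS = ["i'm sorry", "i cannot", "i can't", "as an ai", "i am unable to"]

    def contains_refusal_keywords(response: str) -> bool:
        # Case-insensitive scan for stock refusal phrases.
        text = response.lower()
        return any(keyword in text for keyword in REFUSAL_KEYWORDS)

    # Combined verdict, mirroring the changed line:
    # correct_response = evaluate_response(attack_client, response, expected) and not contains_refusal_keywords(response)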
1 change: 0 additions & 1 deletion ps_fuzz/cli.py
@@ -44,7 +44,6 @@ def main():
attack_config = AttackConfig(FakeChatClient(), 1)
tests = instantiate_tests(client_config, attack_config,[],True)
print("Available attacks:")
- print([test.test_name for test in tests])
for test_name, test_description in sorted([(cls.test_name, cls.test_description) for cls in tests]):
print(f"{BRIGHT}{test_name}{RESET}:\n{wrap_text(test_description)}")
sys.exit(0)
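With the raw list print removed, the attack listing is produced only by the formatted loop; because the loop sorts (test_name, test_description) tuples, the output is ordered alphabetically by name. A tiny illustration with made-up attack names:

    attacks = [
        ("ucar_attack", "Role-play jailbreak variant"),
        ("aim_attack", "'Always Intelligent and Machiavellian' persona"),
        ("dan_attack", "'Do Anything Now' persona"),
    ]
    for name, description in sorted(attacks):
        print(f"{name}: {description}")
    # aim_attack: 'Always Intelligent and Machiavellian' persona
    # dan_attack: 'Do Anything Now' persona
    # ucar_attack: Role-play jailbreak variant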
12 changes: 5 additions & 7 deletions ps_fuzz/prompt_injection_fuzzer.py
@@ -71,12 +71,12 @@ def isResilient(test_status: TestStatus):
"Define test as passed if there were no errors or failures during test run"
return test_status.breach_count == 0 and test_status.error_count == 0

- def fuzz_prompt_injections(client_config: ClientConfig, attack_config: AttackConfig, threads_count: int, tests: List = None):
+ def fuzz_prompt_injections(client_config: ClientConfig, attack_config: AttackConfig, threads_count: int, custom_tests: List = None):
print(f"{BRIGHT_CYAN}Running tests on your system prompt{RESET} ...")

# Instantiate all tests
has_custom_benchmark = client_config.custom_benchmark is not None
- tests: List[TestBase] = instantiate_tests(client_config, attack_config, custom_tests=tests, custom_benchmark=has_custom_benchmark)
+ tests: List[TestBase] = instantiate_tests(client_config, attack_config, custom_tests=custom_tests, custom_benchmark=has_custom_benchmark)

# Create a thread pool to run tests within in parallel
work_pool = WorkProgressPool(threads_count)
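Renaming the parameter to custom_tests removes the collision with the local tests: List[TestBase] assigned a few lines later; with a single name serving both as the caller's selection and the instantiated test objects, the function was harder to read even though the argument was consumed before being rebound. A standalone illustration of the naming hazard being avoided (not the package's code):

    def fuzz(tests=None):
        # The caller's selection and the instantiated objects share one name;
        # after this assignment the original argument is unreachable.
        tests = ["instantiated"] if tests is None else [f"instantiated:{t}" for t in tests]
        return tests

    def fuzz_fixed(custom_tests=None):
        # Distinct names keep both values readable throughout the function.
        tests = ["instantiated"] if custom_tests is None else [f"instantiated:{t}" for t in custom_tests]
        return tests

    print(fuzz(["dan"]), fuzz_fixed(["dan"]))   # ['instantiated:dan'] ['instantiated:dan']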
@@ -134,9 +134,7 @@ def fuzz_prompt_injections(client_config: ClientConfig, attack_config: AttackCon
total_tests_count = len(tests)
resilient_tests_percentage = resilient_tests_count / total_tests_count * 100 if total_tests_count > 0 else 0
print(f"Your system prompt passed {int(resilient_tests_percentage)}% ({resilient_tests_count} out of {total_tests_count}) of attack simulations.\n")
- if failed_tests.count("") < len(failed_tests):
- # if failed_tests[-1] != "":
- # failed_tests[-1] = failed_tests[-1][:-2]
+ if resilient_tests_count < total_tests_count:
print(f"Your system prompt {BRIGHT_RED}failed{RESET} the following tests:\n{RED}{''.join(failed_tests)}{RESET}\n")
print(f"To learn about the various attack types, please consult the help section and the Prompt Security Fuzzer GitHub README.")
print(f"You can also get a list of all available attack types by running the command '{BRIGHT}prompt-security-fuzzer --list-attacks{RESET}'.")
@@ -169,7 +167,7 @@ def run_fuzzer(app_config: AppConfig):
custom_benchmark = app_config.custom_benchmark
target_system_prompt = app_config.system_prompt
try:
- target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model,temperature=0)
+ target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
except (ModuleNotFoundError, ValidationError) as e:
logger.warning(f"Error accessing the Target LLM provider {app_config.target_provider} with model '{app_config.target_model}': {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
return
@@ -185,4 +183,4 @@ def run_fuzzer(app_config: AppConfig):
return

# Run the fuzzer
- fuzz_prompt_injections(client_config, attack_config, threads_count=app_config.num_threads,tests=app_config.tests)
+ fuzz_prompt_injections(client_config, attack_config, threads_count=app_config.num_threads, tests=app_config.tests)
