-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The project structure is updated as follows: ` src/ | - agent/ | - architectures/ | - core/ | - knowledge/ | - llm/ | - memory/ | - tools/ | - api.py ` This change improves maintainability by separating the core Agent features from the specific architecture. The `agent` package contains an interface for agent architectures. The `core` package contains the components used by agent architectures.
- Loading branch information
1 parent
0e752e8
commit 24ffb8e
Showing
26 changed files
with
617 additions
and
485 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1 @@ | ||
from src.agent.knowledge import Collection, Document, Store, Topic | ||
|
||
|
||
def initialize_knowledge(vdb: Store):
    """Create the Knowledge Base collections that are still missing.

    Every available dataset whose title is not already a key of
    `vdb.collections` is created; collections that already exist are
    left untouched.

    :param vdb: the reference to the Knowledge Base"""
    datasets: list[Collection] = Store.get_available_datasets()
    titles = [dataset.title for dataset in datasets]
    print(f"[+] Available Datasets ({titles})")

    # Snapshot of the collection names taken before anything is created.
    known: list[str] = list(vdb.collections.keys())
    print(f"[+] Available Collections ({known})")

    missing = (dataset for dataset in datasets if dataset.title not in known)
    for dataset in missing:
        vdb.create_collection(dataset, progress_bar=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,45 @@ | ||
"""Core component of the system""" | ||
from src.agent.agent import Agent | ||
from src.core import ( | ||
LLM, | ||
AVAILABLE_PROVIDERS, | ||
TOOL_REGISTRY, | ||
Memory | ||
) | ||
from src.agent.agent import Agent, AgentArchitecture | ||
from src.agent.architectures import init_default_architecture | ||
|
||
|
||
# Registry mapping an architecture name to the factory that builds it.
init = {
    'default': init_default_architecture
}


def build_agent(
    model: str,
    inference_endpoint: str,
    architecture_name: str = 'default',
    provider: str = 'ollama',
    provider_key: str = ''
) -> Agent:
    """Build an `Agent` backed by the requested provider and architecture.

    All arguments are validated before the LLM is constructed, so a
    misconfiguration fails fast with a descriptive error.

    :param model: name of the model, forwarded to `LLM`.
    :param inference_endpoint: endpoint forwarded to `LLM`.
    :param architecture_name: key of the architecture factory in `init`.
    :param provider: key of the provider in `AVAILABLE_PROVIDERS`.
    :param provider_key: API key forwarded to `LLM` as `api_key`.
    :returns: an `Agent` wrapping the initialized architecture.
    :raises RuntimeError: if the provider or the architecture is not
        supported, or if the provider requires a key and none is given.
    """
    if provider not in AVAILABLE_PROVIDERS:
        raise RuntimeError(f'{provider} not supported.')
    if architecture_name not in init:
        # Checked up front instead of surfacing as a bare KeyError after
        # the LLM has already been created.
        raise RuntimeError(f'{architecture_name} architecture not supported.')

    provider_spec = AVAILABLE_PROVIDERS[provider]
    llm_provider = provider_spec['class']
    key_required = provider_spec['key_required']

    # Truthiness check also covers None (e.g. an unset environment
    # variable), which would make len() raise TypeError.
    if key_required and not provider_key:
        raise RuntimeError(
            f'Missing PROVIDER_KEY environment variable for {provider}.'
        )

    llm = LLM(
        model=model,
        inference_endpoint=inference_endpoint,
        provider=llm_provider,
        api_key=provider_key
    )
    architecture = init[architecture_name](
        llm=llm,
        tool_registry=TOOL_REGISTRY
    )
    return Agent(architecture)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,133 +1,108 @@ | ||
"""Contains the class `Agent`, the core of the system.""" | ||
|
||
from tool_parse import ToolRegistry | ||
|
||
from src.agent.llm import LLM, ProviderError | ||
from src.agent.memory import Memory, Message, Role | ||
from src.agent.prompts import PROMPTS, PROMPT_VERSION | ||
""" | ||
This module defines the core classes and interfaces for the AI Penetration | ||
Testing Assistant system. It provides the foundational components for handling | ||
user interactions, managing conversations, and processing user inputs using | ||
different agent architectures. | ||
Classes: | ||
- `AgentArchitecture`: An abstract base class that defines the contract for | ||
various architectures used by the assistant to process user queries. | ||
Implementations of this interface can support different models or | ||
strategies for generating responses. | ||
- `Agent`: A high-level interface that manages interaction with the | ||
penetration testing assistant. It abstracts session management, delegates | ||
user queries to a specific `AgentArchitecture` implementation, handles | ||
session persistence. | ||
""" | ||
from abc import ABC, abstractmethod | ||
from typing import Generator | ||
from src.core import Memory, Message, Role | ||
|
||
|
||
class AgentArchitecture(ABC):
    """Contract that every agent architecture must fulfil.

    The `Agent` class delegates to an implementation of this interface, so
    the generation strategy can be swapped or extended without touching
    `Agent` itself."""

    def __init__(self):
        # Each architecture instance owns its own Memory.
        self.memory = Memory()

    @abstractmethod
    def query(
        self,
        session_id: int,
        user_input: str
    ) -> Generator:
        """Process the user input and stream back the response.

        How the response is produced is up to the concrete architecture.

        :param session_id: The session identifier.
        :param user_input: The user's input query.
        :returns: Generator with response text in chunks."""
        raise NotImplementedError

    @abstractmethod
    def new_session(self, session_id: int):
        """Start a new session; the details are left to the implementation.

        :param session_id: The session identifier."""
        raise NotImplementedError
|
||
|
||
class Agent: | ||
"""Penetration Testing Assistant""" | ||
"""This class serves as a high-level interface for managing the interaction | ||
with the penetration testing assistant. It abstracts session management | ||
and delegates querying to a specific `AgentArchitecture` implementation. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
llm: LLM, | ||
tools: str = '', | ||
tool_registry: ToolRegistry | None = None | ||
self, | ||
architecture: AgentArchitecture | ||
): | ||
""" | ||
:param llm: Large Language Model instance | ||
:param tools: documentation of penetration testing tools | ||
:param tool_registry: the available agent tools (ToolRegistry) | ||
:param architecture: the agent architecture implementation that handles | ||
the core query processing and responses. | ||
""" | ||
# Agent Components | ||
self.llm = llm | ||
self.mem = Memory() | ||
self.tool_registry: ToolRegistry | None = tool_registry | ||
if tool_registry is not None and len(tool_registry) > 0: | ||
self.tools = list(self.tool_registry.marshal('base')) | ||
else: | ||
self.tools = [] | ||
|
||
# Prompts | ||
self.model = self.llm.model | ||
prompts = PROMPTS[self.model][PROMPT_VERSION] | ||
self.system_plan_gen = prompts['plan']['system'].format(tools=tools) | ||
self.user_plan_gen = prompts['plan']['user'] | ||
self.system_plan_con = prompts['conversion']['system'] | ||
self.user_plan_con = prompts['conversion']['user'] | ||
|
||
def query(self, sid: int, user_in: str): | ||
"""Performs a query to the Large Language Model, will use RAG | ||
if provided with the necessary tool to perform rag search""" | ||
if not isinstance(user_in, str) or len(user_in) == 0: | ||
raise ValueError(f'Invalid input: {user_in} [{type(user_in)}]') | ||
|
||
# ensure session is initialized (otherwise llm has no system prompt) | ||
if sid not in self.mem.sessions.keys(): | ||
self.new_session(sid) | ||
|
||
# get input for llm | ||
prompt = self.user_plan_gen.format(user=user_in) | ||
usr_msg = Message(Role.USER, prompt) | ||
self.mem.store_message(sid, usr_msg) | ||
messages = self.mem.get_session(sid).message_dict | ||
|
||
# call tools | ||
if self.tools: | ||
tool_response = self.llm.tool_query( | ||
messages, | ||
tools=self.tools | ||
) | ||
# tool results aren't persisted | ||
if tool_response['message'].get('tool_calls'): | ||
results = self.invoke_tools(tool_response) | ||
messages.extend(results) | ||
|
||
# generate response | ||
try: | ||
response = '' | ||
token_usage = 0 | ||
for chunk, tokens in self.llm.query(messages): | ||
yield chunk | ||
response += chunk | ||
|
||
if tokens is not None: | ||
token_usage = tokens | ||
except ProviderError: | ||
raise | ||
|
||
# update memory | ||
sys_msg = Message(Role.ASSISTANT, response) | ||
self.mem.get_session(sid).tokens = token_usage | ||
self.mem.store_message(sid, sys_msg) | ||
|
||
def invoke_tools(self, tool_response): | ||
"""Execute tools (ex. RAG) from llm response""" | ||
results = [] | ||
|
||
call_stack = [] | ||
for tool in tool_response['message']['tool_calls']: | ||
tool_meta = { | ||
'name': tool['function']['name'], | ||
'args': tool['function']['arguments'] | ||
} | ||
|
||
if tool_meta in call_stack: | ||
continue | ||
try: | ||
res = self.tool_registry.compile( | ||
name=tool_meta['name'], | ||
arguments=tool_meta['args'] | ||
) | ||
call_stack.append(tool_meta) | ||
results.append({'role': 'tool', 'content': str(res)}) | ||
except Exception: | ||
pass | ||
|
||
return results | ||
self.agent = architecture | ||
|
||
def query(self, session_id: int, user_input: str): | ||
"""Handles the input from the user and generates responses in a | ||
streaming manner. | ||
This method delegates the query to the specified `AgentArchitecture` | ||
:param session_id: The session identifier. | ||
:param user_input: The user's input query. | ||
:returns: Generator with response text in chunks.""" | ||
if not isinstance(user_input, str) or len(user_input) == 0: | ||
raise ValueError(f'Invalid input: {user_input} [{type(user_input)}]') | ||
|
||
yield from self.agent.query(session_id, user_input) | ||
|
||
def new_session(self, sid: int): | ||
"""Initializes a new conversation""" | ||
self.mem.store_message(sid, Message(Role.SYS, self.system_plan_gen)) | ||
self.agent.new_session(sid) | ||
|
||
def get_session(self, sid: int): | ||
"""Open existing conversation""" | ||
return self.mem.get_session(sid) | ||
return self.agent.memory.get_session(sid) | ||
|
||
def get_sessions(self): | ||
"""Returns list of Session objects""" | ||
return self.mem.get_sessions() | ||
return self.agent.memory.get_sessions() | ||
|
||
def save_session(self, sid: int): | ||
"""Saves the specified session to JSON""" | ||
self.mem.save_session(sid) | ||
self.agent.memory.save_session(sid) | ||
|
||
def delete_session(self, sid: int): | ||
"""Deletes the specified session""" | ||
self.mem.delete_session(sid) | ||
self.agent.memory.delete_session(sid) | ||
|
||
def rename_session(self, sid: int, session_name: str): | ||
"""Rename the specified session""" | ||
self.mem.rename_session(sid, session_name) | ||
self.agent.memory.rename_session(sid, session_name) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from src.agent.architectures.default import ( | ||
DefaultArchitecture, | ||
init_default_architecture | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import json | ||
from pathlib import Path | ||
|
||
from tool_parse import ToolRegistry | ||
|
||
from src.core import LLM | ||
from src.agent.architectures.default.architecture import DefaultArchitecture | ||
|
||
|
||
def init_default_architecture(
    llm: LLM,
    tool_registry: ToolRegistry
) -> DefaultArchitecture:
    """Build a `DefaultArchitecture` using the prompts stored next to this module.

    The prompts are read from `prompts.json`, located in the same directory
    as this file; the `content` of its `router`, `general`, `reasoning` and
    `tool` entries is passed to the architecture.

    :param llm: forwarded to `DefaultArchitecture` as `llm`.
    :param tool_registry: forwarded to `DefaultArchitecture` as `tools`.
    :returns: the initialized `DefaultArchitecture`."""
    prompts_path = Path(__file__).parent / 'prompts.json'
    with open(str(prompts_path), encoding='utf-8') as prompts_file:
        prompts = json.load(prompts_file)

    # Sections are read in the same order as the constructor arguments.
    prompt_kwargs = {
        f'{section}_prompt': prompts[section]['content']
        for section in ('router', 'general', 'reasoning', 'tool')
    }
    return DefaultArchitecture(llm=llm, tools=tool_registry, **prompt_kwargs)
Oops, something went wrong.