Backend refactoring:
The project structure is updated as follows:
```
src/
| - agent/
  | - architectures/
| - core/
  | - knowledge/
  | - llm/
  | - memory/
  | - tools/
| - api.py
```

This change improves maintainability by separating core agent features from the specific agent architecture.
The `agent` package contains the interface for agent architectures; the `core` package contains the components those architectures rely on.
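
A rough usage sketch of the new layout, based on the `build_agent` helper added in this commit; the model name, endpoint, and query are placeholders, and the Ollama provider is assumed not to require an API key:

```python
from src.agent import build_agent

# build_agent resolves the provider, creates the LLM and wires the chosen
# architecture ('default' here) with the shared tool registry.
agent = build_agent(
    model='mistral',                              # placeholder model name
    inference_endpoint='http://localhost:11434',  # placeholder endpoint
    architecture_name='default',
    provider='ollama'
)

session_id = 1
agent.new_session(session_id)

# Agent.query delegates to the architecture and streams response chunks.
for chunk in agent.query(session_id, 'How do I enumerate open ports on a host?'):
    print(chunk, end='', flush=True)
```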
antoninoLorenzo committed Nov 15, 2024
1 parent 0e752e8 commit 24ffb8e
Showing 26 changed files with 617 additions and 485 deletions.
15 changes: 0 additions & 15 deletions src/__init__.py
@@ -1,16 +1 @@
from src.agent.knowledge import Collection, Document, Store, Topic


def initialize_knowledge(vdb: Store):
    """Used to initialize and keep updated the Knowledge Base.
    Already existing Collections will not be overwritten.
    :param vdb: the reference to the Knowledge Base"""
    available: list[Collection] = Store.get_available_datasets()
    print(f"[+] Available Datasets ({[c.title for c in available]})")

    existing: list[str] = list(vdb.collections.keys())
    print(f"[+] Available Collections ({existing})")

    for collection in available:
        if collection.title not in existing:
            vdb.create_collection(collection, progress_bar=True)
45 changes: 44 additions & 1 deletion src/agent/__init__.py
@@ -1,2 +1,45 @@
"""Core component of the system"""
from src.agent.agent import Agent
from src.core import (
    LLM,
    AVAILABLE_PROVIDERS,
    TOOL_REGISTRY,
    Memory
)
from src.agent.agent import Agent, AgentArchitecture
from src.agent.architectures import init_default_architecture


init = {
    'default': init_default_architecture
}


def build_agent(
    model: str,
    inference_endpoint: str,
    architecture_name: str = 'default',
    provider: str = 'ollama',
    provider_key: str = ''
) -> Agent:
    """Resolves the LLM provider, builds the LLM and wires the selected
    architecture with the shared tool registry, returning a ready Agent."""
    if provider not in AVAILABLE_PROVIDERS.keys():
        raise RuntimeError(f'{provider} not supported.')
    llm_provider = AVAILABLE_PROVIDERS[provider]['class']
    key_required = AVAILABLE_PROVIDERS[provider]['key_required']

    if key_required and len(provider_key) == 0:
        raise RuntimeError(
            f'Missing PROVIDER_KEY environment variable for {provider}.'
        )

    llm = LLM(
        model=model,
        inference_endpoint=inference_endpoint,
        provider=llm_provider,
        api_key=provider_key
    )
    architecture = init[architecture_name](
        llm=llm,
        tool_registry=TOOL_REGISTRY
    )
    return Agent(architecture)
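
`build_agent` reads `AVAILABLE_PROVIDERS[provider]['class']` and `['key_required']`, so each registry entry presumably has roughly the shape sketched below; the `OllamaProvider` name is illustrative and not taken from this diff:

```python
# Hypothetical stand-in for a provider implementation; the real provider
# classes live in src.core and are not shown in this commit.
class OllamaProvider:
    pass


# Assumed shape of an AVAILABLE_PROVIDERS entry, inferred from how
# build_agent reads it: entry['class'] is passed to LLM(provider=...),
# entry['key_required'] gates the provider_key check.
AVAILABLE_PROVIDERS = {
    'ollama': {
        'class': OllamaProvider,
        'key_required': False
    }
}
```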

189 changes: 82 additions & 107 deletions src/agent/agent.py
@@ -1,133 +1,108 @@
"""Contains the class `Agent`, the core of the system."""

from tool_parse import ToolRegistry

from src.agent.llm import LLM, ProviderError
from src.agent.memory import Memory, Message, Role
from src.agent.prompts import PROMPTS, PROMPT_VERSION
"""
This module defines the core classes and interfaces for the AI Penetration
Testing Assistant system. It provides the foundational components for handling
user interactions, managing conversations, and processing user inputs using
different agent architectures.
Classes:
- `AgentArchitecture`: An abstract base class that defines the contract for
various architectures used by the assistant to process user queries.
Implementations of this interface can support different models or
strategies for generating responses.
- `Agent`: A high-level interface that manages interaction with the
penetration testing assistant. It abstracts session management, delegates
user queries to a specific `AgentArchitecture` implementation, handles
session persistence.
"""
from abc import ABC, abstractmethod
from typing import Generator
from src.core import Memory, Message, Role


class AgentArchitecture(ABC):
    """Interface defining the contract for various agent architectures.
    This interface abstracts the underlying generation strategies used by
    the `Agent` class, allowing the implementation of multiple architectures
    that can be easily swapped or extended."""

    def __init__(self):
        self.memory = Memory()

    @abstractmethod
    def query(
        self,
        session_id: int,
        user_input: str
    ) -> Generator:
        """Handles the input from the user and generates responses in a
        streaming manner. The exact behavior depends on the specific
        implementation of the strategy.
        :param session_id: The session identifier.
        :param user_input: The user's input query.
        :returns: Generator with response text in chunks."""
        raise NotImplementedError()

    @abstractmethod
    def new_session(self, session_id: int):
        raise NotImplementedError()


class Agent:
    """Penetration Testing Assistant"""
    """This class serves as a high-level interface for managing the interaction
    with the penetration testing assistant. It abstracts session management
    and delegates querying to a specific `AgentArchitecture` implementation.
    """

    def __init__(
        self,
        llm: LLM,
        tools: str = '',
        tool_registry: ToolRegistry | None = None
        self,
        architecture: AgentArchitecture
    ):
        """
        :param llm: Large Language Model instance
        :param tools: documentation of penetration testing tools
        :param tool_registry: the available agent tools (ToolRegistry)
        :param architecture: the agent architecture implementation that handles
            the core query processing and responses.
        """
        # Agent Components
        self.llm = llm
        self.mem = Memory()
        self.tool_registry: ToolRegistry | None = tool_registry
        if tool_registry is not None and len(tool_registry) > 0:
            self.tools = list(self.tool_registry.marshal('base'))
        else:
            self.tools = []

        # Prompts
        self.model = self.llm.model
        prompts = PROMPTS[self.model][PROMPT_VERSION]
        self.system_plan_gen = prompts['plan']['system'].format(tools=tools)
        self.user_plan_gen = prompts['plan']['user']
        self.system_plan_con = prompts['conversion']['system']
        self.user_plan_con = prompts['conversion']['user']

    def query(self, sid: int, user_in: str):
        """Performs a query to the Large Language Model, will use RAG
        if provided with the necessary tool to perform rag search"""
        if not isinstance(user_in, str) or len(user_in) == 0:
            raise ValueError(f'Invalid input: {user_in} [{type(user_in)}]')

        # ensure session is initialized (otherwise llm has no system prompt)
        if sid not in self.mem.sessions.keys():
            self.new_session(sid)

        # get input for llm
        prompt = self.user_plan_gen.format(user=user_in)
        usr_msg = Message(Role.USER, prompt)
        self.mem.store_message(sid, usr_msg)
        messages = self.mem.get_session(sid).message_dict

        # call tools
        if self.tools:
            tool_response = self.llm.tool_query(
                messages,
                tools=self.tools
            )
            # tool results aren't persisted
            if tool_response['message'].get('tool_calls'):
                results = self.invoke_tools(tool_response)
                messages.extend(results)

        # generate response
        try:
            response = ''
            token_usage = 0
            for chunk, tokens in self.llm.query(messages):
                yield chunk
                response += chunk

                if tokens is not None:
                    token_usage = tokens
        except ProviderError:
            raise

        # update memory
        sys_msg = Message(Role.ASSISTANT, response)
        self.mem.get_session(sid).tokens = token_usage
        self.mem.store_message(sid, sys_msg)

    def invoke_tools(self, tool_response):
        """Execute tools (ex. RAG) from llm response"""
        results = []

        call_stack = []
        for tool in tool_response['message']['tool_calls']:
            tool_meta = {
                'name': tool['function']['name'],
                'args': tool['function']['arguments']
            }

            if tool_meta in call_stack:
                continue
            try:
                res = self.tool_registry.compile(
                    name=tool_meta['name'],
                    arguments=tool_meta['args']
                )
                call_stack.append(tool_meta)
                results.append({'role': 'tool', 'content': str(res)})
            except Exception:
                pass

        return results
        self.agent = architecture

    def query(self, session_id: int, user_input: str):
        """Handles the input from the user and generates responses in a
        streaming manner.
        This method delegates the query to the specified `AgentArchitecture`
        :param session_id: The session identifier.
        :param user_input: The user's input query.
        :returns: Generator with response text in chunks."""
        if not isinstance(user_input, str) or len(user_input) == 0:
            raise ValueError(f'Invalid input: {user_input} [{type(user_input)}]')

        yield from self.agent.query(session_id, user_input)

    def new_session(self, sid: int):
        """Initializes a new conversation"""
        self.mem.store_message(sid, Message(Role.SYS, self.system_plan_gen))
        self.agent.new_session(sid)

    def get_session(self, sid: int):
        """Open existing conversation"""
        return self.mem.get_session(sid)
        return self.agent.memory.get_session(sid)

    def get_sessions(self):
        """Returns list of Session objects"""
        return self.mem.get_sessions()
        return self.agent.memory.get_sessions()

    def save_session(self, sid: int):
        """Saves the specified session to JSON"""
        self.mem.save_session(sid)
        self.agent.memory.save_session(sid)

    def delete_session(self, sid: int):
        """Deletes the specified session"""
        self.mem.delete_session(sid)
        self.agent.memory.delete_session(sid)

    def rename_session(self, sid: int, session_name: str):
        """Rename the specified session"""
        self.mem.rename_session(sid, session_name)
        self.agent.memory.rename_session(sid, session_name)
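
For reference, a custom architecture only has to subclass `AgentArchitecture` and implement `query` and `new_session`; the toy implementation below just echoes the input and is not part of this commit:

```python
from typing import Generator

from src.agent.agent import Agent, AgentArchitecture


class EchoArchitecture(AgentArchitecture):
    """Toy architecture: streams the user input straight back."""

    def query(self, session_id: int, user_input: str) -> Generator:
        # A real architecture would build the prompt, call the LLM and
        # yield response chunks here.
        yield f'echo: {user_input}'

    def new_session(self, session_id: int):
        # The inherited self.memory (a core Memory instance) is available
        # for per-session state; this toy implementation needs none.
        pass


if __name__ == '__main__':
    agent = Agent(EchoArchitecture())
    agent.new_session(1)
    print(''.join(agent.query(1, 'hello')))
```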

4 changes: 4 additions & 0 deletions src/agent/architectures/__init__.py
@@ -0,0 +1,4 @@
from src.agent.architectures.default import (
    DefaultArchitecture,
    init_default_architecture
)
28 changes: 28 additions & 0 deletions src/agent/architectures/default/__init__.py
@@ -0,0 +1,28 @@
import json
from pathlib import Path

from tool_parse import ToolRegistry

from src.core import LLM
from src.agent.architectures.default.architecture import DefaultArchitecture


def init_default_architecture(
    llm: LLM,
    tool_registry: ToolRegistry
) -> DefaultArchitecture:
    """Loads the default architecture prompts from `prompts.json` and
    instantiates `DefaultArchitecture` with them."""
    with open(
        str(Path(__file__).parent / 'prompts.json'),
        encoding='utf-8'
    ) as fp:
        prompts = json.load(fp)

    return DefaultArchitecture(
        llm=llm,
        tools=tool_registry,
        router_prompt=prompts['router']['content'],
        general_prompt=prompts['general']['content'],
        reasoning_prompt=prompts['reasoning']['content'],
        tool_prompt=prompts['tool']['content']
    )
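
`init_default_architecture` indexes `prompts['router']['content']`, `['general']`, `['reasoning']` and `['tool']`, so `prompts.json` is assumed to follow roughly the shape below; the prompt texts are placeholders, not the prompts shipped with the commit:

```python
import json

# Assumed structure of prompts.json, inferred from the keys the loader reads.
prompts = {
    'router':    {'content': '<system prompt that routes the user query>'},
    'general':   {'content': '<system prompt for general conversation>'},
    'reasoning': {'content': '<system prompt for the reasoning step>'},
    'tool':      {'content': '<system prompt for tool selection>'}
}

with open('prompts.json', 'w', encoding='utf-8') as fp:
    json.dump(prompts, fp, indent=2)
```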