Skip to content

Commit

Permalink
Fixed Logger class to be used with/without orchestrator
Browse files Browse the repository at this point in the history
  • Loading branch information
brnaba-aws committed Oct 16, 2024
1 parent a06aa41 commit 7047a69
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/chat-chainlit-app/ollamaAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def handle_streaming_response(self, messages: List[Dict[str, str]]) -> Con
)

except Exception as error:
Logger.logger.error("Error getting stream from Ollama model:", error)
Logger.error("Error getting stream from Ollama model:", error)
raise error


Expand Down
10 changes: 5 additions & 5 deletions python/src/multi_agent_orchestrator/agents/chain_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,25 @@ async def process_request(
current_input = response.content[0]['text']
final_response = response
else:
Logger.logger.warning(f"Agent {agent.name} returned no text content.")
Logger.warn(f"Agent {agent.name} returned no text content.")
return self.create_default_response()
elif self.is_async_iterable(response):
if not is_last_agent:
Logger.logger.warning(f"Intermediate agent {agent.name} returned a streaming response, which is not allowed.")
Logger.warn(f"Intermediate agent {agent.name} returned a streaming response, which is not allowed.")
return self.create_default_response()
# It's the last agent and streaming is allowed
final_response = response
else:
Logger.logger.warning(f"Agent {agent.name} returned an invalid response type.")
Logger.warn(f"Agent {agent.name} returned an invalid response type.")
return self.create_default_response()

# If it's not the last agent, ensure we have a non-streaming response to pass to the next agent
if not is_last_agent and not self.is_conversation_message(final_response):
Logger.logger.error(f"Expected non-streaming response from intermediate agent {agent.name}")
Logger.error(f"Expected non-streaming response from intermediate agent {agent.name}")
return self.create_default_response()

except Exception as error:
Logger.logger.error(f"Error processing request with agent {agent.name}:{str(error)}")
Logger.error(f"Error processing request with agent {agent.name}:{str(error)}")
raise f"Error processing request with agent {agent.name}:{str(error)}"

return final_response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def process_request(self,
issues.append(custom_issue)

if issues:
Logger.logger.warning(f"Content filter issues detected: {'; '.join(issues)}")
Logger.warn(f"Content filter issues detected: {'; '.join(issues)}")
return None # Return None to indicate content should not be processed further

# If no issues, return the original input as a ConversationMessage
Expand All @@ -95,7 +95,7 @@ async def process_request(self,
)

except Exception as error:
Logger.logger.error(f"Error in ComprehendContentFilterAgent:{str(error)}")
Logger.error(f"Error in ComprehendContentFilterAgent:{str(error)}")
raise error

def add_custom_check(self, check: CheckFunction):
Expand Down
13 changes: 6 additions & 7 deletions python/src/multi_agent_orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ async def dispatch_to_agent(self,
agent_chat_history = await self.storage.fetch_chat(user_id, session_id, selected_agent.id)

self.logger.print_chat_history(agent_chat_history, selected_agent.id)
#self.logger.info(f"Routing intent '{user_input}' to {selected_agent.id} ...")

response = await self.measure_execution_time(
f"Agent {selected_agent.name} | Processing request",
Expand All @@ -110,9 +109,9 @@ async def route_request(self,
session_id: str,
additional_params: Dict[str, str] = {}) -> AgentResponse:
self.execution_times.clear()
chat_history = await self.storage.fetch_all_chats(user_id, session_id) or []

try:
chat_history = await self.storage.fetch_all_chats(user_id, session_id) or []
classifier_result:ClassifierResult = await self.measure_execution_time(
"Classifying user intent",
lambda: self.classifier.classify(user_input, chat_history)
Expand Down Expand Up @@ -210,13 +209,13 @@ async def route_request(self,

def print_intent(self, user_input: str, intent_classifier_result: ClassifierResult) -> None:
"""Print the classified intent."""
Logger.log_header('Classified Intent')
Logger.logger.info(f"> Text: {user_input}")
Logger.logger.info(f"> Selected Agent: {intent_classifier_result.selected_agent.name \
self.logger.log_header('Classified Intent')
self.logger.info(f"> Text: {user_input}")
self.logger.info(f"> Selected Agent: {intent_classifier_result.selected_agent.name \
if intent_classifier_result.selected_agent \
else 'No agent selected'}")
Logger.logger.info(f"> Confidence: {intent_classifier_result.confidence:.2f}")
Logger.logger.info('')
self.logger.info(f"> Confidence: {intent_classifier_result.confidence:.2f}")
self.logger.info('')

async def measure_execution_time(self, timer_name: str, fn):
if not self.config.LOG_EXECUTION_TIMES:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def save_chat_message(
try:
self.table.put_item(Item=item)
except Exception as error:
Logger.logger.error(f"Error saving conversation to DynamoDB:{str(error)}")
Logger.error(f"Error saving conversation to DynamoDB:{str(error)}")
raise error

return self._remove_timestamps(trimmed_conversation)
Expand All @@ -76,7 +76,7 @@ async def fetch_chat(
)
return self._remove_timestamps(stored_messages)
except Exception as error:
Logger.logger.error(f"Error getting conversation from DynamoDB:{str(error)}")
Logger.error(f"Error getting conversation from DynamoDB:{str(error)}")
raise error

async def fetch_chat_with_timestamp(
Expand All @@ -93,7 +93,7 @@ async def fetch_chat_with_timestamp(
)
return stored_messages
except Exception as error:
Logger.logger.error(f"Error getting conversation from DynamoDB: {str(error)}")
Logger.error(f"Error getting conversation from DynamoDB: {str(error)}")
raise error

async def fetch_all_chats(self, user_id: str, session_id: str) -> List[ConversationMessage]:
Expand All @@ -112,7 +112,7 @@ async def fetch_all_chats(self, user_id: str, session_id: str) -> List[Conversat
all_chats = []
for item in response['Items']:
if not isinstance(item.get('conversation'), list):
Logger.logger.error(f"Unexpected item structure:{item}")
Logger.error(f"Unexpected item structure:{item}")
continue

agent_id = item['SK'].split('#')[1]
Expand All @@ -134,7 +134,7 @@ async def fetch_all_chats(self, user_id: str, session_id: str) -> List[Conversat
all_chats.sort(key=lambda x: x.timestamp)
return self._remove_timestamps(all_chats)
except Exception as error:
Logger.logger.error(f"Error querying conversations from DynamoDB:{str(error)}")
Logger.error(f"Error querying conversations from DynamoDB:{str(error)}")
raise error

def _generate_key(self, user_id: str, session_id: str, agent_id: str) -> str:
Expand Down
56 changes: 36 additions & 20 deletions python/src/multi_agent_orchestrator/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,57 @@
logging.basicConfig(level=logging.INFO)

class Logger:
_instance = None
_logger = None

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self,
config: Optional[Dict[str, bool]] = None,
logger: Optional[logging.Logger] = logging.getLogger(__name__)):
logger: Optional[logging.Logger] = None):
if not hasattr(self, 'initialized'):
Logger._logger = logger or logging.getLogger(__name__)
self.initialized = True
self.config: OrchestratorConfig = config or OrchestratorConfig()
self.set_logger(logger or logging.getLogger(__name__))

@classmethod
def get_logger(cls):
if cls._logger is None:
cls._logger = logging.getLogger(__name__)
return cls._logger

@classmethod
def set_logger(cls, logger: Any) -> None:
cls.logger = logger
cls._logger = logger

@classmethod
def info(cls, message: str, *args: Any) -> None:
"""Log an info message."""
cls.logger.info(message, *args)
cls.get_logger().info(message, *args)

@classmethod
def warn(cls, message: str, *args: Any) -> None:
"""Log a warning message."""
cls.logger.info(message, *args)
cls.get_logger().info(message, *args)

@classmethod
def error(cls, message: str, *args: Any) -> None:
"""Log an error message."""
cls.logger.error(message, *args)
cls.get_logger().error(message, *args)

@classmethod
def debug(cls, message: str, *args: Any) -> None:
"""Log a debug message."""
cls.logger.debug(message, *args)
cls.get_logger().debug(message, *args)

@classmethod
def log_header(cls, title: str) -> None:
"""Log a header with the given title."""
cls.logger.info(f"\n** {title.upper()} **")
cls.logger.info('=' * (len(title) + 6))
cls.get_logger().info(f"\n** {title.upper()} **")
cls.get_logger().info('=' * (len(title) + 6))

def print_chat_history(self,
chat_history: List[ConversationMessage],
Expand All @@ -52,39 +68,39 @@ def print_chat_history(self,
return

title = f"Agent {agent_id} Chat History" if is_agent_chat else 'Classifier Chat History'
Logger.log_header(title)
self.log_header(title)

if not chat_history:
Logger.logger.info('> - None -')
self.get_logger().info('> - None -')
else:
for index, message in enumerate(chat_history, 1):
role = message.role.upper()
content = message.content
text = content[0] if isinstance(content, list) else content
text = text.get('text', '') if isinstance(text, dict) else str(text)
trimmed_text = f"{text[:80]}..." if len(text) > 80 else text
Logger.logger.info(f"> {index}. {role}: {trimmed_text}")
Logger.logger.info('')
self.get_logger().info(f"> {index}. {role}: {trimmed_text}")
self.get_logger().info('')

def log_classifier_output(self, output: Any, is_raw: bool = False) -> None:
"""Log the classifier output."""
if (is_raw and not self.config.LOG_CLASSIFIER_RAW_OUTPUT) or \
(not is_raw and not self.config.LOG_CLASSIFIER_OUTPUT):
return

Logger.log_header('Raw Classifier Output' if is_raw else 'Processed Classifier Output')
Logger.logger.info(output if is_raw else json.dumps(output, indent=2))
Logger.logger.info('')
self.log_header('Raw Classifier Output' if is_raw else 'Processed Classifier Output')
self.get_logger().info(output if is_raw else json.dumps(output, indent=2))
self.get_logger().info('')

def print_execution_times(self, execution_times: Dict[str, float]) -> None:
"""Print execution times."""
if not self.config.LOG_EXECUTION_TIMES:
return

Logger.log_header('Execution Times')
self.log_header('Execution Times')
if not execution_times:
Logger.logger.info('> - None -')
self.get_logger().info('> - None -')
else:
for timer_name, duration in execution_times.items():
Logger.logger.info(f"> {timer_name}: {duration}s")
Logger.logger.info('')
self.get_logger().info(f"> {timer_name}: {duration}s")
self.get_logger().info('')
5 changes: 3 additions & 2 deletions python/src/tests/utils/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ def mock_logger(mocker):
def test_logger_initialization():
logger = Logger()
assert isinstance(logger.config, OrchestratorConfig)
assert isinstance(logger.logger, logging.Logger)
assert isinstance(logger._logger, logging.Logger)

def test_logger_initialization_with_custom_config():
custom_config = OrchestratorConfig(**{'LOG_AGENT_CHAT': True, 'LOG_CLASSIFIER_CHAT': False})
print(custom_config)
logger = Logger(config=custom_config)
assert logger.config == custom_config

def test_set_logger(mock_logger):
Logger.set_logger(mock_logger)
assert Logger.logger == mock_logger
assert Logger._logger == mock_logger

@pytest.mark.parametrize("log_method", ["info", "info", "error", "debug"])
def test_log_methods(mock_logger, log_method):
Expand Down

0 comments on commit 7047a69

Please sign in to comment.