From 7047a69de6c8c1cb5f061983f3046435a1d6cd79 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 16 Oct 2024 18:00:53 +0200 Subject: [PATCH] Fixed Logger class to be used with/without orchestrator --- examples/chat-chainlit-app/ollamaAgent.py | 2 +- .../agents/chain_agent.py | 10 ++-- .../agents/comprehend_filter_agent.py | 4 +- .../multi_agent_orchestrator/orchestrator.py | 13 ++--- .../storage/dynamodb_chat_storage.py | 10 ++-- .../multi_agent_orchestrator/utils/logger.py | 56 ++++++++++++------- python/src/tests/utils/test_logger.py | 5 +- 7 files changed, 58 insertions(+), 42 deletions(-) diff --git a/examples/chat-chainlit-app/ollamaAgent.py b/examples/chat-chainlit-app/ollamaAgent.py index ba0fe9e5..0ca47467 100644 --- a/examples/chat-chainlit-app/ollamaAgent.py +++ b/examples/chat-chainlit-app/ollamaAgent.py @@ -34,7 +34,7 @@ async def handle_streaming_response(self, messages: List[Dict[str, str]]) -> Con ) except Exception as error: - Logger.logger.error("Error getting stream from Ollama model:", error) + Logger.error("Error getting stream from Ollama model:", error) raise error diff --git a/python/src/multi_agent_orchestrator/agents/chain_agent.py b/python/src/multi_agent_orchestrator/agents/chain_agent.py index aba22055..ac8dee98 100644 --- a/python/src/multi_agent_orchestrator/agents/chain_agent.py +++ b/python/src/multi_agent_orchestrator/agents/chain_agent.py @@ -43,25 +43,25 @@ async def process_request( current_input = response.content[0]['text'] final_response = response else: - Logger.logger.warning(f"Agent {agent.name} returned no text content.") + Logger.warn(f"Agent {agent.name} returned no text content.") return self.create_default_response() elif self.is_async_iterable(response): if not is_last_agent: - Logger.logger.warning(f"Intermediate agent {agent.name} returned a streaming response, which is not allowed.") + Logger.warn(f"Intermediate agent {agent.name} returned a streaming response, which is not allowed.") return self.create_default_response() # It's the last agent and streaming is allowed final_response = response else: - Logger.logger.warning(f"Agent {agent.name} returned an invalid response type.") + Logger.warn(f"Agent {agent.name} returned an invalid response type.") return self.create_default_response() # If it's not the last agent, ensure we have a non-streaming response to pass to the next agent if not is_last_agent and not self.is_conversation_message(final_response): - Logger.logger.error(f"Expected non-streaming response from intermediate agent {agent.name}") + Logger.error(f"Expected non-streaming response from intermediate agent {agent.name}") return self.create_default_response() except Exception as error: - Logger.logger.error(f"Error processing request with agent {agent.name}:{str(error)}") + Logger.error(f"Error processing request with agent {agent.name}:{str(error)}") raise f"Error processing request with agent {agent.name}:{str(error)}" return final_response diff --git a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py index 0943d5be..4d7f07f0 100644 --- a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py +++ b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py @@ -85,7 +85,7 @@ async def process_request(self, issues.append(custom_issue) if issues: - Logger.logger.warning(f"Content filter issues detected: {'; '.join(issues)}") + Logger.warn(f"Content filter issues detected: {'; '.join(issues)}") return None # Return None to indicate content should not be processed further # If no issues, return the original input as a ConversationMessage @@ -95,7 +95,7 @@ async def process_request(self, ) except Exception as error: - Logger.logger.error(f"Error in ComprehendContentFilterAgent:{str(error)}") + Logger.error(f"Error in ComprehendContentFilterAgent:{str(error)}") raise error def add_custom_check(self, check: CheckFunction): diff --git a/python/src/multi_agent_orchestrator/orchestrator.py b/python/src/multi_agent_orchestrator/orchestrator.py index 476bef74..a97c7070 100644 --- a/python/src/multi_agent_orchestrator/orchestrator.py +++ b/python/src/multi_agent_orchestrator/orchestrator.py @@ -91,7 +91,6 @@ async def dispatch_to_agent(self, agent_chat_history = await self.storage.fetch_chat(user_id, session_id, selected_agent.id) self.logger.print_chat_history(agent_chat_history, selected_agent.id) - #self.logger.info(f"Routing intent '{user_input}' to {selected_agent.id} ...") response = await self.measure_execution_time( f"Agent {selected_agent.name} | Processing request", @@ -110,9 +109,9 @@ async def route_request(self, session_id: str, additional_params: Dict[str, str] = {}) -> AgentResponse: self.execution_times.clear() - chat_history = await self.storage.fetch_all_chats(user_id, session_id) or [] try: + chat_history = await self.storage.fetch_all_chats(user_id, session_id) or [] classifier_result:ClassifierResult = await self.measure_execution_time( "Classifying user intent", lambda: self.classifier.classify(user_input, chat_history) @@ -210,13 +209,13 @@ async def route_request(self, def print_intent(self, user_input: str, intent_classifier_result: ClassifierResult) -> None: """Print the classified intent.""" - Logger.log_header('Classified Intent') - Logger.logger.info(f"> Text: {user_input}") - Logger.logger.info(f"> Selected Agent: {intent_classifier_result.selected_agent.name \ + self.logger.log_header('Classified Intent') + self.logger.info(f"> Text: {user_input}") + self.logger.info(f"> Selected Agent: {intent_classifier_result.selected_agent.name \ if intent_classifier_result.selected_agent \ else 'No agent selected'}") - Logger.logger.info(f"> Confidence: {intent_classifier_result.confidence:.2f}") - Logger.logger.info('') + self.logger.info(f"> Confidence: {intent_classifier_result.confidence:.2f}") + self.logger.info('') async def measure_execution_time(self, timer_name: str, fn): if not self.config.LOG_EXECUTION_TIMES: diff --git a/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py b/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py index 16e3b2e7..5cb3a011 100644 --- a/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py +++ b/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py @@ -57,7 +57,7 @@ async def save_chat_message( try: self.table.put_item(Item=item) except Exception as error: - Logger.logger.error(f"Error saving conversation to DynamoDB:{str(error)}") + Logger.error(f"Error saving conversation to DynamoDB:{str(error)}") raise error return self._remove_timestamps(trimmed_conversation) @@ -76,7 +76,7 @@ async def fetch_chat( ) return self._remove_timestamps(stored_messages) except Exception as error: - Logger.logger.error(f"Error getting conversation from DynamoDB:{str(error)}") + Logger.error(f"Error getting conversation from DynamoDB:{str(error)}") raise error async def fetch_chat_with_timestamp( @@ -93,7 +93,7 @@ async def fetch_chat_with_timestamp( ) return stored_messages except Exception as error: - Logger.logger.error(f"Error getting conversation from DynamoDB: {str(error)}") + Logger.error(f"Error getting conversation from DynamoDB: {str(error)}") raise error async def fetch_all_chats(self, user_id: str, session_id: str) -> List[ConversationMessage]: @@ -112,7 +112,7 @@ async def fetch_all_chats(self, user_id: str, session_id: str) -> List[Conversat all_chats = [] for item in response['Items']: if not isinstance(item.get('conversation'), list): - Logger.logger.error(f"Unexpected item structure:{item}") + Logger.error(f"Unexpected item structure:{item}") continue agent_id = item['SK'].split('#')[1] @@ -134,7 +134,7 @@ async def fetch_all_chats(self, user_id: str, session_id: str) -> List[Conversat all_chats.sort(key=lambda x: x.timestamp) return self._remove_timestamps(all_chats) except Exception as error: - Logger.logger.error(f"Error querying conversations from DynamoDB:{str(error)}") + Logger.error(f"Error querying conversations from DynamoDB:{str(error)}") raise error def _generate_key(self, user_id: str, session_id: str, agent_id: str) -> str: diff --git a/python/src/multi_agent_orchestrator/utils/logger.py b/python/src/multi_agent_orchestrator/utils/logger.py index e4631780..d3a7b0c7 100644 --- a/python/src/multi_agent_orchestrator/utils/logger.py +++ b/python/src/multi_agent_orchestrator/utils/logger.py @@ -6,41 +6,57 @@ logging.basicConfig(level=logging.INFO) class Logger: + _instance = None + _logger = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + def __init__(self, config: Optional[Dict[str, bool]] = None, - logger: Optional[logging.Logger] = logging.getLogger(__name__)): + logger: Optional[logging.Logger] = None): + if not hasattr(self, 'initialized'): + Logger._logger = logger or logging.getLogger(__name__) + self.initialized = True self.config: OrchestratorConfig = config or OrchestratorConfig() - self.set_logger(logger or logging.getLogger(__name__)) + + @classmethod + def get_logger(cls): + if cls._logger is None: + cls._logger = logging.getLogger(__name__) + return cls._logger @classmethod def set_logger(cls, logger: Any) -> None: - cls.logger = logger + cls._logger = logger @classmethod def info(cls, message: str, *args: Any) -> None: """Log an info message.""" - cls.logger.info(message, *args) + cls.get_logger().info(message, *args) @classmethod def warn(cls, message: str, *args: Any) -> None: """Log a warning message.""" - cls.logger.info(message, *args) + cls.get_logger().info(message, *args) @classmethod def error(cls, message: str, *args: Any) -> None: """Log an error message.""" - cls.logger.error(message, *args) + cls.get_logger().error(message, *args) @classmethod def debug(cls, message: str, *args: Any) -> None: """Log a debug message.""" - cls.logger.debug(message, *args) + cls.get_logger().debug(message, *args) @classmethod def log_header(cls, title: str) -> None: """Log a header with the given title.""" - cls.logger.info(f"\n** {title.upper()} **") - cls.logger.info('=' * (len(title) + 6)) + cls.get_logger().info(f"\n** {title.upper()} **") + cls.get_logger().info('=' * (len(title) + 6)) def print_chat_history(self, chat_history: List[ConversationMessage], @@ -52,10 +68,10 @@ def print_chat_history(self, return title = f"Agent {agent_id} Chat History" if is_agent_chat else 'Classifier Chat History' - Logger.log_header(title) + self.log_header(title) if not chat_history: - Logger.logger.info('> - None -') + self.get_logger().info('> - None -') else: for index, message in enumerate(chat_history, 1): role = message.role.upper() @@ -63,8 +79,8 @@ def print_chat_history(self, text = content[0] if isinstance(content, list) else content text = text.get('text', '') if isinstance(text, dict) else str(text) trimmed_text = f"{text[:80]}..." if len(text) > 80 else text - Logger.logger.info(f"> {index}. {role}: {trimmed_text}") - Logger.logger.info('') + self.get_logger().info(f"> {index}. {role}: {trimmed_text}") + self.get_logger().info('') def log_classifier_output(self, output: Any, is_raw: bool = False) -> None: """Log the classifier output.""" @@ -72,19 +88,19 @@ def log_classifier_output(self, output: Any, is_raw: bool = False) -> None: (not is_raw and not self.config.LOG_CLASSIFIER_OUTPUT): return - Logger.log_header('Raw Classifier Output' if is_raw else 'Processed Classifier Output') - Logger.logger.info(output if is_raw else json.dumps(output, indent=2)) - Logger.logger.info('') + self.log_header('Raw Classifier Output' if is_raw else 'Processed Classifier Output') + self.get_logger().info(output if is_raw else json.dumps(output, indent=2)) + self.get_logger().info('') def print_execution_times(self, execution_times: Dict[str, float]) -> None: """Print execution times.""" if not self.config.LOG_EXECUTION_TIMES: return - Logger.log_header('Execution Times') + self.log_header('Execution Times') if not execution_times: - Logger.logger.info('> - None -') + self.get_logger().info('> - None -') else: for timer_name, duration in execution_times.items(): - Logger.logger.info(f"> {timer_name}: {duration}s") - Logger.logger.info('') + self.get_logger().info(f"> {timer_name}: {duration}s") + self.get_logger().info('') diff --git a/python/src/tests/utils/test_logger.py b/python/src/tests/utils/test_logger.py index 270573fc..913e3bfb 100644 --- a/python/src/tests/utils/test_logger.py +++ b/python/src/tests/utils/test_logger.py @@ -15,16 +15,17 @@ def mock_logger(mocker): def test_logger_initialization(): logger = Logger() assert isinstance(logger.config, OrchestratorConfig) - assert isinstance(logger.logger, logging.Logger) + assert isinstance(logger._logger, logging.Logger) def test_logger_initialization_with_custom_config(): custom_config = OrchestratorConfig(**{'LOG_AGENT_CHAT': True, 'LOG_CLASSIFIER_CHAT': False}) + print(custom_config) logger = Logger(config=custom_config) assert logger.config == custom_config def test_set_logger(mock_logger): Logger.set_logger(mock_logger) - assert Logger.logger == mock_logger + assert Logger._logger == mock_logger @pytest.mark.parametrize("log_method", ["info", "info", "error", "debug"]) def test_log_methods(mock_logger, log_method):