diff --git a/lavague-core/lavague/core/agents.py b/lavague-core/lavague/core/agents.py
index 04df3b67..0d3a4a83 100644
--- a/lavague-core/lavague/core/agents.py
+++ b/lavague-core/lavague/core/agents.py
@@ -22,7 +22,7 @@
 from lavague.core.token_counter import TokenCounter
 from lavague.core.utilities.config import is_flag_true
-from lavague.core.utilities.profiling import ChartGenerator, track_total_runtime, start_new_step
+from lavague.core.utilities.profiling import ChartGenerator, profile_agent, start_new_step, clear_profiling_data
 
 logging_print = logging.getLogger(__name__)
 logging_print.setLevel(logging.INFO)
 
@@ -448,7 +448,7 @@ def _run_step_gradio(
         self.process_token_usage()
         self.logger.end_step()
 
-    @track_total_runtime()
+    @profile_agent(event_type="RUN_STEP")
     def run_step(self, objective: str) -> Optional[ActionResult]:
         obs = self.driver.get_obs()
         current_state, past = self.st_memory.get_state()
diff --git a/lavague-core/lavague/core/navigation.py b/lavague-core/lavague/core/navigation.py
index bede4907..466253d9 100644
--- a/lavague-core/lavague/core/navigation.py
+++ b/lavague-core/lavague/core/navigation.py
@@ -24,7 +24,7 @@
 from PIL import Image
 from llama_index.core.base.llms.base import BaseLLM
 from llama_index.core.embeddings import BaseEmbedding
-from lavague.core.utilities.profiling import track_retriever, track_llm_call
+from lavague.core.utilities.profiling import profile_agent
 
 NAVIGATION_ENGINE_PROMPT_TEMPLATE = ActionTemplate(
     """
@@ -144,7 +144,7 @@ def from_context(
             extractor,
         )
 
-    @track_retriever()
+    @profile_agent(event_type="RETRIEVER_CALL")
     def get_nodes(self, query: str) -> List[str]:
         """
         Get the nodes from the html page
@@ -456,7 +456,7 @@ def execute_instruction(self, instruction: str) -> ActionResult:
         )
 
         # response = self.llm.complete(prompt).text
-        response = track_llm_call("Navigation")(self.llm.complete)(prompt).text
+        response = profile_agent(event_type="LLM_CALL", event_name="Navigation Engine")(self.llm.complete)(prompt).text
         end = time.time()
         action_generation_time = end - start
         action_outcome = {
@@ -480,6 +480,7 @@ def execute_instruction(self, instruction: str) -> ActionResult:
             for item in vision_data:
                 display_screenshot(item["screenshot"])
                 time.sleep(0.2)
-        self.driver.exec_code(action)
+
+        profile_agent(event_type="DEFAULT", event_name="Execute code")(self.driver.exec_code)(action)
         time.sleep(self.time_between_actions)
         if self.display:
diff --git a/lavague-core/lavague/core/world_model.py b/lavague-core/lavague/core/world_model.py
index 0e3be5f2..c22d9729 100644
--- a/lavague-core/lavague/core/world_model.py
+++ b/lavague-core/lavague/core/world_model.py
@@ -11,7 +11,7 @@
 from lavague.core.utilities.model_utils import get_model_name
 import time
 import yaml
-from lavague.core.utilities.profiling import track_llm_call
+from lavague.core.utilities.profiling import profile_agent
 
 WORLD_MODEL_GENERAL_EXAMPLES = """
 Objective: Go to the first issue you can find
@@ -433,7 +433,7 @@ def get_instruction(
         start = time.time()
 
         # decorated llm call
-        mm_llm_output = track_llm_call("World Model")(mm_llm.complete)(
+        mm_llm_output = profile_agent(event_type="LLM_CALL", event_name="World Model")(mm_llm.complete)(
             prompt, image_documents=image_documents
         ).text
         # mm_llm_output = mm_llm.complete(prompt, image_documents=image_documents).text
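Note: the diff above collapses three special-purpose trackers (`track_total_runtime`, `track_retriever`, `track_llm_call`) into a single parameterized `profile_agent` factory, used both as a method decorator and as an inline wrapper around individual calls. The sketch below is a minimal illustration of how such a factory could be structured, assuming a simple in-memory event buffer (`PROFILING_EVENTS` is a hypothetical name); the real implementation in `lavague/core/utilities/profiling.py` may differ, e.g. it also backs `ChartGenerator` and the per-step bookkeeping behind `start_new_step` and `clear_profiling_data`.

```python
# Minimal sketch of a unified profiling decorator factory.
# Hypothetical: the actual profile_agent in lavague/core/utilities/profiling.py
# may record different fields and store them elsewhere.
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Optional

# Assumed in-memory buffer for recorded events (illustrative only).
PROFILING_EVENTS: List[Dict[str, Any]] = []


def profile_agent(
    event_type: str = "DEFAULT", event_name: Optional[str] = None
) -> Callable[[Callable], Callable]:
    """Return a decorator that times a call and records one profiling event."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            start = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                # Record the event even if the wrapped call raises.
                PROFILING_EVENTS.append(
                    {
                        "event_type": event_type,
                        "event_name": event_name or func.__qualname__,
                        "duration": time.time() - start,
                    }
                )

        return wrapper

    return decorator
```

Because the factory returns an ordinary decorator, both call sites in the diff work unchanged: `@profile_agent(event_type="RUN_STEP")` above `run_step`, and the inline form `profile_agent(event_type="LLM_CALL", event_name="World Model")(mm_llm.complete)(prompt, ...)`, which profiles a single invocation without permanently wrapping the callee.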