From 0df6a69b3e0694e609ec68a68b7c23f0c92968d4 Mon Sep 17 00:00:00 2001 From: Nevan Wichers Date: Sat, 16 Nov 2024 19:15:02 -0800 Subject: [PATCH] Allow for custom metrics. Allow for passing auxiliary data from the run_episode function to the process_episodes. Also allows a custom process_episodes function to interpret this auxiliary data. PiperOrigin-RevId: 697268080 --- android_world/constants.py | 3 +++ android_world/episode_runner.py | 2 ++ android_world/suite_utils.py | 15 ++++++++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/android_world/constants.py b/android_world/constants.py index be91711..783bf93 100644 --- a/android_world/constants.py +++ b/android_world/constants.py @@ -33,6 +33,8 @@ class EpisodeConstants: EPISODE_LENGTH: The length of the episode. FINISH_DTIME: The datetime the task finished. SEED: The random seed to initialize the current episode's task. + AUX_DATA: Additional data which can be passed from the task to + process_episodes. """ EPISODE_DATA = 'episode_data' @@ -48,3 +50,4 @@ class EpisodeConstants: EXCEPTION_INFO = 'exception_info' FINISH_DTIME = 'finish_dtime' SEED = 'seed' + AUX_DATA = 'aux_data' diff --git a/android_world/episode_runner.py b/android_world/episode_runner.py index a6df1dc..eabdbc0 100644 --- a/android_world/episode_runner.py +++ b/android_world/episode_runner.py @@ -31,11 +31,13 @@ class EpisodeResult: done: Whether the agent indicated the task is complete. step_data: Environment and agent data for each step. env_reward: Reward returned by environment, if applicable. + aux_data: Additional data from the episode which may be used for metrics. """ done: bool step_data: dict[str, Any] env_reward: Optional[float] = None + aux_data: Optional[dict[str, Any]] = None def run_episode( diff --git a/android_world/suite_utils.py b/android_world/suite_utils.py index 53d4cf5..ffac508 100644 --- a/android_world/suite_utils.py +++ b/android_world/suite_utils.py @@ -264,6 +264,7 @@ def _run_task( constants.EpisodeConstants.EPISODE_LENGTH: len( interaction_results.step_data[constants.STEP_NUMBER] ), + constants.EpisodeConstants.AUX_DATA: interaction_results.aux_data, constants.EpisodeConstants.SCREEN_CONFIG: _get_screen_config(task), constants.EpisodeConstants.EXCEPTION_INFO: None, constants.EpisodeConstants.SEED: task.params[ @@ -309,6 +310,7 @@ def _run_task_suite( demo_mode: bool = False, agent_name: str = '', return_full_episode_data: bool = False, + process_episodes_fn=None, ) -> list[dict[str, Any]]: """Runs e2e system on suite. @@ -321,6 +323,8 @@ def _run_task_suite( agent_name: The name of the agent. return_full_episode_data: Whether to return full episode data instead of just metadata. + process_episodes_fn: The function to process episode data. Usually to + compute metrics. Deafaults to process_episodes from this file. Returns: Metadata for each episode, including the scripted reward. @@ -333,10 +337,14 @@ def _run_task_suite( constants.EpisodeConstants.EPISODE_LENGTH, constants.EpisodeConstants.RUN_TIME, constants.EpisodeConstants.EXCEPTION_INFO, + constants.EpisodeConstants.AUX_DATA, ] completed_tasks, failed_tasks = _get_task_info( checkpointer.load(fields=metadata_fields) ) + if process_episodes_fn is None: + process_episodes_fn = process_episodes + if (completed_tasks or failed_tasks) and return_full_episode_data: raise ValueError( 'Cannot return full episode data when resuming from a checkpoint.' @@ -376,7 +384,7 @@ def _run_task_suite( full_episode_data.append(episode) episodes_metadata.append({k: episode[k] for k in metadata_fields}) - process_episodes(episodes_metadata, print_summary=True) + process_episodes_fn(episodes_metadata, print_summary=True) if episode[constants.EpisodeConstants.EXCEPTION_INFO] is not None: # Don't include episode in tally if execution/eval logic errored out. @@ -396,6 +404,7 @@ def run( checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(), demo_mode: bool = False, return_full_episode_data: bool = False, + process_episodes_fn=None, ) -> list[dict[str, Any]]: """Create suite and runs eval suite. @@ -410,6 +419,8 @@ def run( task instruction as a notification. return_full_episode_data: Whether to return full episode data instead of just metadata. + process_episodes_fn: The function to process episode data. Usually to + compute metrics. Deafaults to process_episodes from this file. Returns: Step-by-step data from each episode. @@ -446,6 +457,7 @@ def run_episode(task: task_eval.TaskEval) -> episode_runner.EpisodeResult: demo_mode=demo_mode, agent_name=agent.name, return_full_episode_data=return_full_episode_data, + process_episodes_fn=process_episodes_fn, ) return results @@ -517,6 +529,7 @@ def _create_failed_result( constants.EpisodeConstants.RUN_TIME: run_time, constants.EpisodeConstants.EPISODE_LENGTH: np.nan, constants.EpisodeConstants.EXCEPTION_INFO: exception, + constants.EpisodeConstants.AUX_DATA: None, }