Skip to content

Commit

Permalink
Allow for custom metrics.
Browse files Browse the repository at this point in the history
Allow for passing auxiliary data from the run_episode function to the process_episodes. Also allows a custom process_episodes function to interpret this auxiliary data.

PiperOrigin-RevId: 697268080
  • Loading branch information
wichersn authored and The android_world Authors committed Nov 26, 2024
1 parent fd13643 commit 0df6a69
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
3 changes: 3 additions & 0 deletions android_world/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class EpisodeConstants:
EPISODE_LENGTH: The length of the episode.
FINISH_DTIME: The datetime the task finished.
SEED: The random seed to initialize the current episode's task.
AUX_DATA: Additional data which can be passed from the task to
process_episodes.
"""

EPISODE_DATA = 'episode_data'
Expand All @@ -48,3 +50,4 @@ class EpisodeConstants:
EXCEPTION_INFO = 'exception_info'
FINISH_DTIME = 'finish_dtime'
SEED = 'seed'
AUX_DATA = 'aux_data'
2 changes: 2 additions & 0 deletions android_world/episode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ class EpisodeResult:
done: Whether the agent indicated the task is complete.
step_data: Environment and agent data for each step.
env_reward: Reward returned by environment, if applicable.
aux_data: Additional data from the episode which may be used for metrics.
"""

done: bool
step_data: dict[str, Any]
env_reward: Optional[float] = None
aux_data: Optional[dict[str, Any]] = None


def run_episode(
Expand Down
15 changes: 14 additions & 1 deletion android_world/suite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def _run_task(
constants.EpisodeConstants.EPISODE_LENGTH: len(
interaction_results.step_data[constants.STEP_NUMBER]
),
constants.EpisodeConstants.AUX_DATA: interaction_results.aux_data,
constants.EpisodeConstants.SCREEN_CONFIG: _get_screen_config(task),
constants.EpisodeConstants.EXCEPTION_INFO: None,
constants.EpisodeConstants.SEED: task.params[
Expand Down Expand Up @@ -309,6 +310,7 @@ def _run_task_suite(
demo_mode: bool = False,
agent_name: str = '',
return_full_episode_data: bool = False,
process_episodes_fn=None,
) -> list[dict[str, Any]]:
"""Runs e2e system on suite.
Expand All @@ -321,6 +323,8 @@ def _run_task_suite(
agent_name: The name of the agent.
return_full_episode_data: Whether to return full episode data instead of
just metadata.
process_episodes_fn: The function to process episode data. Usually to
compute metrics. Deafaults to process_episodes from this file.
Returns:
Metadata for each episode, including the scripted reward.
Expand All @@ -333,10 +337,14 @@ def _run_task_suite(
constants.EpisodeConstants.EPISODE_LENGTH,
constants.EpisodeConstants.RUN_TIME,
constants.EpisodeConstants.EXCEPTION_INFO,
constants.EpisodeConstants.AUX_DATA,
]
completed_tasks, failed_tasks = _get_task_info(
checkpointer.load(fields=metadata_fields)
)
if process_episodes_fn is None:
process_episodes_fn = process_episodes

if (completed_tasks or failed_tasks) and return_full_episode_data:
raise ValueError(
'Cannot return full episode data when resuming from a checkpoint.'
Expand Down Expand Up @@ -376,7 +384,7 @@ def _run_task_suite(
full_episode_data.append(episode)

episodes_metadata.append({k: episode[k] for k in metadata_fields})
process_episodes(episodes_metadata, print_summary=True)
process_episodes_fn(episodes_metadata, print_summary=True)

if episode[constants.EpisodeConstants.EXCEPTION_INFO] is not None:
# Don't include episode in tally if execution/eval logic errored out.
Expand All @@ -396,6 +404,7 @@ def run(
checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(),
demo_mode: bool = False,
return_full_episode_data: bool = False,
process_episodes_fn=None,
) -> list[dict[str, Any]]:
"""Create suite and runs eval suite.
Expand All @@ -410,6 +419,8 @@ def run(
task instruction as a notification.
return_full_episode_data: Whether to return full episode data instead of
just metadata.
process_episodes_fn: The function to process episode data. Usually to
compute metrics. Deafaults to process_episodes from this file.
Returns:
Step-by-step data from each episode.
Expand Down Expand Up @@ -446,6 +457,7 @@ def run_episode(task: task_eval.TaskEval) -> episode_runner.EpisodeResult:
demo_mode=demo_mode,
agent_name=agent.name,
return_full_episode_data=return_full_episode_data,
process_episodes_fn=process_episodes_fn,
)

return results
Expand Down Expand Up @@ -517,6 +529,7 @@ def _create_failed_result(
constants.EpisodeConstants.RUN_TIME: run_time,
constants.EpisodeConstants.EPISODE_LENGTH: np.nan,
constants.EpisodeConstants.EXCEPTION_INFO: exception,
constants.EpisodeConstants.AUX_DATA: None,
}


Expand Down

0 comments on commit 0df6a69

Please sign in to comment.