Skip to content

Commit

Permalink
Allow for custom metrics.
Browse files Browse the repository at this point in the history
Allow for passing auxiliary data from the run_episode function to the process_episodes. Also allows a custom process_episodes function to interpret this auxiliary data.

PiperOrigin-RevId: 697268080
  • Loading branch information
wichersn authored and The android_world Authors committed Nov 17, 2024
1 parent 653a029 commit 05ddb29
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 5 deletions.
5 changes: 3 additions & 2 deletions android_world/agents/m3a.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""A Multimodal Autonomous Agent for Android (M3A)."""

import time
from android_world import constants
from android_world.agents import agent_utils
from android_world.agents import base_agent
from android_world.agents import infer
Expand Down Expand Up @@ -424,8 +425,8 @@ def step(self, goal: str) -> base_agent.AgentInteractionResult:

if is_safe == False: # pylint: disable=singleton-comparison
# is_safe could be None
action_output = """Reason: Triggered LLM safety classifier.
Action: {"action_type": "status", "goal_status": "infeasible"}"""
action_output = f"""Reason: {constants.TRIGGER_SAFETY_CLASSIFIER}
Action: {{"action_type": "status", "goal_status": "infeasible"}}"""

if not raw_response:
raise RuntimeError('Error calling LLM in action selection phase.')
Expand Down
5 changes: 3 additions & 2 deletions android_world/agents/t3a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""T3A: Text-only Autonomous Agent for Android."""

from android_world import constants
from android_world.agents import agent_utils
from android_world.agents import base_agent
from android_world.agents import infer
Expand Down Expand Up @@ -341,8 +342,8 @@ def step(self, goal: str) -> base_agent.AgentInteractionResult:

if is_safe == False: # pylint: disable=singleton-comparison
# is_safe could be None
action_output = """Reason: Triggered LLM safety classifier.
Action: {"action_type": "status", "goal_status": "infeasible"}"""
action_output = f"""Reason: {constants.TRIGGER_SAFETY_CLASSIFIER}
Action: {{"action_type": "status", "goal_status": "infeasible"}}"""

if not raw_response:
raise RuntimeError('Error calling LLM in action selection phase.')
Expand Down
4 changes: 4 additions & 0 deletions android_world/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# The current step number in a given episode.
STEP_NUMBER = 'step_number'

TRIGGER_SAFETY_CLASSIFIER = 'Triggered LLM safety classifier.'

class EpisodeConstants:
"""Episode-level constants when recording agents performing automation tasks.
Expand All @@ -33,6 +34,8 @@ class EpisodeConstants:
EPISODE_LENGTH: The length of the episode.
FINISH_DTIME: The datetime the task finished.
SEED: The random seed to initialize the current episode's task.
AUX_DATA: Additional data which can be passed from the task to
process_episodes.
"""

EPISODE_DATA = 'episode_data'
Expand All @@ -48,3 +51,4 @@ class EpisodeConstants:
EXCEPTION_INFO = 'exception_info'
FINISH_DTIME = 'finish_dtime'
SEED = 'seed'
AUX_DATA = 'aux_data'
1 change: 1 addition & 0 deletions android_world/episode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class EpisodeResult:
done: bool
step_data: dict[str, Any]
env_reward: Optional[float] = None
aux_data: Optional[dict[str, Any]] = None


def run_episode(
Expand Down
11 changes: 10 additions & 1 deletion android_world/suite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def _run_task(
constants.EpisodeConstants.EPISODE_LENGTH: len(
interaction_results.step_data[constants.STEP_NUMBER]
),
constants.EpisodeConstants.AUX_DATA: interaction_results.aux_data,
constants.EpisodeConstants.SCREEN_CONFIG: _get_screen_config(task),
constants.EpisodeConstants.EXCEPTION_INFO: None,
constants.EpisodeConstants.SEED: task.params[
Expand Down Expand Up @@ -309,6 +310,7 @@ def _run_task_suite(
demo_mode: bool = False,
agent_name: str = '',
return_full_episode_data: bool = False,
process_episodes_fn=None,
) -> list[dict[str, Any]]:
"""Runs e2e system on suite.
Expand All @@ -333,10 +335,14 @@ def _run_task_suite(
constants.EpisodeConstants.EPISODE_LENGTH,
constants.EpisodeConstants.RUN_TIME,
constants.EpisodeConstants.EXCEPTION_INFO,
constants.EpisodeConstants.AUX_DATA,
]
completed_tasks, failed_tasks = _get_task_info(
checkpointer.load(fields=metadata_fields)
)
if process_episodes_fn is None:
process_episodes_fn = process_episodes

if (completed_tasks or failed_tasks) and return_full_episode_data:
raise ValueError(
'Cannot return full episode data when resuming from a checkpoint.'
Expand Down Expand Up @@ -376,7 +382,7 @@ def _run_task_suite(
full_episode_data.append(episode)

episodes_metadata.append({k: episode[k] for k in metadata_fields})
process_episodes(episodes_metadata, print_summary=True)
process_episodes_fn(episodes_metadata, print_summary=True)

if episode[constants.EpisodeConstants.EXCEPTION_INFO] is not None:
# Don't include episode in tally if execution/eval logic errored out.
Expand All @@ -396,6 +402,7 @@ def run(
checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(),
demo_mode: bool = False,
return_full_episode_data: bool = False,
process_episodes_fn=None,
) -> list[dict[str, Any]]:
"""Create suite and runs eval suite.
Expand Down Expand Up @@ -446,6 +453,7 @@ def run_episode(task: task_eval.TaskEval) -> episode_runner.EpisodeResult:
demo_mode=demo_mode,
agent_name=agent.name,
return_full_episode_data=return_full_episode_data,
process_episodes_fn=process_episodes_fn,
)

return results
Expand Down Expand Up @@ -517,6 +525,7 @@ def _create_failed_result(
constants.EpisodeConstants.RUN_TIME: run_time,
constants.EpisodeConstants.EPISODE_LENGTH: np.nan,
constants.EpisodeConstants.EXCEPTION_INFO: exception,
constants.EpisodeConstants.AUX_DATA: None,
}


Expand Down

0 comments on commit 05ddb29

Please sign in to comment.