From be7ae98c1fb9d0c5292874b0622488e8a3a5a4c8 Mon Sep 17 00:00:00 2001 From: Joseph Doyle Date: Mon, 7 Mar 2022 19:15:34 -0500 Subject: [PATCH 1/9] initial commit with pseudocode --- examples/QTable/model/new_runner.py | 80 +++++++++++++ .../src/mivp_agent/episodic_manager.py | 108 ++++++++++++++++++ 2 files changed, 188 insertions(+) create mode 100755 examples/QTable/model/new_runner.py create mode 100644 src/python_module/src/mivp_agent/episodic_manager.py diff --git a/examples/QTable/model/new_runner.py b/examples/QTable/model/new_runner.py new file mode 100755 index 0000000..e3552ad --- /dev/null +++ b/examples/QTable/model/new_runner.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +import argparse +import time + +from mivp_agent.episodic_manager import EpisodicManager +from mivp_agent.util.math import dist +from mivp_agent.util.display import ModelConsole +from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG + +from constants import DEFAULT_RUN_MODEL +from model import load_model + + +def run(args): + q, attack_actions, retreat_actions = load_model(args.model) + + with EpisodicManager('runner', log=False) as mgr: + print('Waiting for sim vehicle connections...') + while mgr.get_vehicle_count() < 1: + time.sleep(0.1) + # --------------------------------------- + # Part 1: Asserting simulation state + + last_state = None + current_action = None + current_action_set = None + console = ModelConsole() + + while True: + # Listen for state + msg = mgr.get_message() + while False: + print('-------------------------------------------') + print(f"({msg.vname}) {msg.state['HAS_FLAG']}") + print('-------------------------------------------') + msg.request_new() + msg = mgr.get_message() + + console.tick(msg) + + # Detect state transitions + model_state = q.get_state( + msg.state['NAV_X'], + msg.state['NAV_Y'], + msg.state['NODE_REPORTS'][args.enemy]['NAV_X'], + msg.state['NODE_REPORTS'][args.enemy]['NAV_Y'], + msg.state['HAS_FLAG'] + ) + + # Detect state transition + if model_state != last_state: + current_action = q.get_action(model_state) + last_state = model_state + + # Determine action set + if msg.state['HAS_FLAG']: + current_action_set = retreat_actions + else: + current_action_set = attack_actions + + # Construct instruction for BHV_Agent + action = { + 'speed': current_action_set[current_action]['speed'], + 'course': current_action_set[current_action]['course'] + } + + flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG)) + if flag_dist < 10: + action['posts']= { + 'FLAG_GRAB_REQUEST': f'vname={msg.vname}' + } + + msg.act(action) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default=DEFAULT_RUN_MODEL) + parser.add_argument('--enemy', default='drone_21') + args = parser.parse_args() + run(args) \ No newline at end of file diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py new file mode 100644 index 0000000..4211bf7 --- /dev/null +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -0,0 +1,108 @@ +# +# ---------------------------------------------- +# +# - wait_for might not be needed as a param to Episodic init + + +class AgentData: + def __init__(self, vname) -> None: + self.vname = vname + self.last_episode = None + self.last_state = None + + self.current_action = None + +class EpisodicManager: + def __init__(self, agents, episodes, wait_for=None) -> None: + ''' + SETUP AGENTS + ''' + self.agents = agents + self.agent_data = {} + + 
# Combine the agent's id and any additional wait fors + self.wait_for = wait_for + + if self.wait_for is None: + self.wait_for = [] + + # Setup things specific to each agent + for a in self.agents: + # Make sure they are waited for + self.wait_for.append(a.id()) + + # Make a agent state for them + self.agent_data[a.id()] = AgentData(a.id()) + + ''' + SETUP EPISODE TRACKING + ''' + self.episodes = episodes + self.current_episode = 0 + + + def run(self, task, log=True): + with MissionManager(task, log=log) as mgr: + mgr.wait_for(self.wait_for) + + while self.current_episode < self.episodes: + msg = msg.get_message() + + # Find agent in list... + for a in self.agents: + if msg.vname == a.id(): + data = self.agent_data[a.id()] + + # Probably always + state = a.obs_to_state(msg.state) + + if data.last_state != state: + msg.transition() + data.current_action = a.state_to_act(state) + + # Update episode count if applicable + if data.last_episode != msg.episode_report['NUM']: + data.last_episode = msg.episode_report['NUM'] + + ################################################ + # Importantly, update the global episode count # + self.episodes += 1 + + # track data + data.last_state = state + + ################################################ + # Importantly, actually do shit # + msg.act(data.current_action) + + +class Agent: + def __init__(self, own_id, opponent_id) -> None: + self.own_id = own_id + self.opponent_id = opponent_id + + def id(self): + return self.own_id + + def obs_to_state(self, observation): + # obs -> state + state = observation + 1 + + return state + + def state_to_act(self, state): + act = my_model(state) + + return act + +if __name__ == '__main__': + # Create agents required + + agents = [] + wait_for = [] + for i in [1,2,3,4,5]: + agents.append(Agent(f'agent_1{i}', f'drone_2{i}')) + wait_for.append(f'drone_2') + + mgr = EpisodicManager(agents, 1000, wait_for=wait_for) + mgr.run() \ No newline at end of file From e57ed8552135ad8b3e7db75c4ec00459539c94ab Mon Sep 17 00:00:00 2001 From: Joseph Doyle Date: Thu, 10 Mar 2022 15:51:14 -0500 Subject: [PATCH 2/9] this version basically just adds run() to MissionManager and renames to EpisodicManager --- examples/QTable/model/new_runner.py | 68 +-- examples/QTable/run.sh | 2 +- .../src/mivp_agent/episodic_manager.py | 460 +++++++++++++++--- 3 files changed, 383 insertions(+), 147 deletions(-) diff --git a/examples/QTable/model/new_runner.py b/examples/QTable/model/new_runner.py index e3552ad..670bdc6 100755 --- a/examples/QTable/model/new_runner.py +++ b/examples/QTable/model/new_runner.py @@ -3,78 +3,16 @@ import time from mivp_agent.episodic_manager import EpisodicManager -from mivp_agent.util.math import dist -from mivp_agent.util.display import ModelConsole -from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG from constants import DEFAULT_RUN_MODEL from model import load_model -def run(args): - q, attack_actions, retreat_actions = load_model(args.model) - - with EpisodicManager('runner', log=False) as mgr: - print('Waiting for sim vehicle connections...') - while mgr.get_vehicle_count() < 1: - time.sleep(0.1) - # --------------------------------------- - # Part 1: Asserting simulation state - - last_state = None - current_action = None - current_action_set = None - console = ModelConsole() - - while True: - # Listen for state - msg = mgr.get_message() - while False: - print('-------------------------------------------') - print(f"({msg.vname}) {msg.state['HAS_FLAG']}") - print('-------------------------------------------') - 
msg.request_new() - msg = mgr.get_message() - - console.tick(msg) - - # Detect state transitions - model_state = q.get_state( - msg.state['NAV_X'], - msg.state['NAV_Y'], - msg.state['NODE_REPORTS'][args.enemy]['NAV_X'], - msg.state['NODE_REPORTS'][args.enemy]['NAV_Y'], - msg.state['HAS_FLAG'] - ) - - # Detect state transition - if model_state != last_state: - current_action = q.get_action(model_state) - last_state = model_state - - # Determine action set - if msg.state['HAS_FLAG']: - current_action_set = retreat_actions - else: - current_action_set = attack_actions - - # Construct instruction for BHV_Agent - action = { - 'speed': current_action_set[current_action]['speed'], - 'course': current_action_set[current_action]['course'] - } - - flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG)) - if flag_dist < 10: - action['posts']= { - 'FLAG_GRAB_REQUEST': f'vname={msg.vname}' - } - - msg.act(action) - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', default=DEFAULT_RUN_MODEL) parser.add_argument('--enemy', default='drone_21') args = parser.parse_args() - run(args) \ No newline at end of file + q, attack_actions, retreat_actions = load_model(args.model) + with EpisodicManager('runner', log=False) as mgr: + mgr.run(q, attack_actions, retreat_actions) diff --git a/examples/QTable/run.sh b/examples/QTable/run.sh index e78a8a4..fa5ebd8 100755 --- a/examples/QTable/run.sh +++ b/examples/QTable/run.sh @@ -32,6 +32,6 @@ sleep 5 # Start trainer with any arguments passed to this script echo "Launching runner..." -./model/runner.py "$@" +./model/new_runner.py "$@" cd $PREVIOUS_WD \ No newline at end of file diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py index 4211bf7..f561390 100644 --- a/src/python_module/src/mivp_agent/episodic_manager.py +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -1,108 +1,406 @@ -# -# ---------------------------------------------- -# -# - wait_for might not be needed as a param to Episodic init +# General +import os +import time +from queue import Queue, Empty +from threading import Thread, Lock +# For core +from mivp_agent.const import KEY_ID, DATA_DIRECTORY +from mivp_agent.messages import MissionMessage, INSTR_SEND_STATE, INSTR_RESET_FAILURE, INSTR_RESET_SUCCESS +from mivp_agent.bridge import ModelBridgeServer +from mivp_agent.util.display import ModelConsole +from mivp_agent.util.math import dist +from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG -class AgentData: - def __init__(self, vname) -> None: - self.vname = vname - self.last_episode = None - self.last_state = None - self.current_action = None +# For logging +from mivp_agent.log.directory import LogDirectory +from mivp_agent.proto.proto_logger import ProtoLogger +from mivp_agent.proto.mivp_agent_pb2 import Transition +from mivp_agent.proto import translate class EpisodicManager: - def __init__(self, agents, episodes, wait_for=None) -> None: ''' - SETUP AGENTS + This is the primary method for interfacing with moos-ivp-agent's BHV_Agent + + Examples: + It is recommended to use EpisodicManager with the python context manager + + ``` + from mivp_agent.manager import EpisodicManager + + with EpisodicManager('trainer') as mgr: + mgr.wait_for(['felix', 'evan']) + ... 
+ ``` ''' - self.agents = agents - self.agent_data = {} - - # Combine the agent's id and any additional wait fors - self.wait_for = wait_for - if self.wait_for is None: - self.wait_for = [] - - # Setup things specific to each agent - for a in self.agents: - # Make sure they are waited for - self.wait_for.append(a.id()) + def __init__(self, task, log=True, immediate_transition=True, log_whitelist=None, id_suffix=None, output_dir=None): + ''' + The initializer for EpisodicManager + + Args: + task (str): For organization of saved data type is required to specify what task the EpisodicManager is preforming. For example a `EpisodicManager('trainer')` will log data under `generated_files/trainer/` in the current working directory. + + log (bool): Logging of agent transitions can be disabled by setting this to `False`. + + immediate_transition (bool): By default the the manager will assume that all messages received from BHV_Agents represent a new transition. If set to `False` one must manually tell set `msg.is_transition = True` on any objects returned from `get_message()`. This is helpful when you want to control what is considered a "state" in your Markov Decsion Process. + + log_whitelist (list): Setting this parameter will only log some transitions according to their reported `vnames`. + + id_suffix (str): Will be appended to the generated session id. + + output_dir (str): Path to a place to store files. + ''' + self._msg_queue = Queue() + + self._vnames = [] + self._vname_lock = Lock() + self._vehicle_count = 0 + self._episode_manager_states = {} + self._ems_lock = Lock() + self._episode_manager_nums = {} + self._emn_lock = Lock() + + # Dict to hold queues of vnames to reset + self._vresets = Queue() + + self._thread = None + self._stop_signal = False + + if output_dir is None: + output_dir = os.path.join( + os.path.abspath(os.getcwd()), + DATA_DIRECTORY + ) + + self._log_dir = LogDirectory(output_dir) + self._id = self._init_session(id_suffix) + + # Calculate the path to the directories we will be writing to. This will be created when we first use them / return them to the user. + self._model_path = os.path.join( + self._log_dir.models_dir(), + self._id + ) + self._log_path = os.path.join( + self._log_dir.task_dir(task), + self._id + ) + self._model_path = os.path.abspath(self._model_path) + self._log_path = os.path.abspath(self._log_path) - # Make a agent state for them - self.agent_data[a.id()] = AgentData(a.id()) + self._log = log + self._imm_transition = immediate_transition + if self._log: + self._log_whitelist = log_whitelist + # Create data structs needed to log data from each vehicle + self._logs = {} + self._last_state = {} + self._last_act = {} + + # Go ahead and create the log path + os.makedirs(self._log_path) + + def _init_session(self, id_suffix): + # Start the session id with the current timestamp + id = str(round(time.time())) + + # Add suffix if it exists + if id_suffix is not None: + id += f"-{id_suffix}" + + id = self._log_dir.meta.registry.register(id) + + return id + + def model_output_dir(self): + if not os.path.isdir(self._model_path): + os.makedirs(self._model_path) + return self._model_path - ''' - SETUP EPISODE TRACKING - ''' - self.episodes = episodes - self.current_episode = 0 + def log_output_dir(self): + assert self._log, "This method should not be used, when logging is disabled" + return self._log_path + + def __enter__(self): + self.start() + return self + + def start(self): + ''' + It is **not recommended** to use this method directly. 
Instead, consider using this class with the python context manager. This method starts a thread to read from the `ModelBridgeServer`. + + Returns: + bool: False if thread has already been started, True otherwise + ''' + if self._thread is not None: + return False + + self._thread = Thread(target=self._server_thread, daemon=True) + self._thread.start() + + return True + + def _server_thread(self): + live_msg_list = [] + address_map = {} + with ModelBridgeServer() as server: + while not self._stop_signal: + # Accept new clients + addr = server.accept() + if addr is not None: + print(f'Got new connection: {addr}') + server.send_instr(addr, INSTR_SEND_STATE) + + # Listen for messages from vehicles + for addr in server._clients: + msg = server.listen(addr) + if msg is not None: + with self._vname_lock: + if msg[KEY_ID] not in self._vnames: + print(f'Got new vehicle: {msg[KEY_ID]}') + vname = msg[KEY_ID] + address_map[vname] = addr + self._vnames.append(vname) + self._vehicle_count += 1 - def run(self, task, log=True): - with MissionManager(task, log=log) as mgr: - mgr.wait_for(self.wait_for) + assert address_map[msg[KEY_ID]] == addr, "Vehicle changed vname. This violates routing / logging assumptions made by EpisodicManager" - while self.current_episode < self.episodes: - msg = msg.get_message() + m = MissionMessage( + addr, + msg, + is_transition=self._imm_transition + ) - # Find agent in list... - for a in self.agents: - if msg.vname == a.id(): - data = self.agent_data[a.id()] + with self._ems_lock: + self._episode_manager_states[m.vname] = m.episode_state + with self._emn_lock: + if m.episode_report is None: + self._episode_manager_nums[m.vname] = None + else: + self._episode_manager_nums[m.vname] = m.episode_report['NUM'] - # Probably always - state = a.obs_to_state(msg.state) + live_msg_list.append(m) + self._msg_queue.put(m) - if data.last_state != state: - msg.transition() - data.current_action = a.state_to_act(state) + # Send responses to vehicle message if there are any + for i, m in enumerate(live_msg_list): + with m._rsp_lock: + if m._response is None: + continue - # Update episode count if applicable - if data.last_episode != msg.episode_report['NUM']: - data.last_episode = msg.episode_report['NUM'] - - ################################################ - # Importantly, update the global episode count # - self.episodes += 1 + # If we got there is response send and remove from list + live_msg_list.remove(m) + server.send_instr(m._addr, m._response) - # track data - data.last_state = state + # Do logging + self._do_logging(m) - ################################################ - # Importantly, actually do shit # - msg.act(data.current_action) + # Handle reseting of vehicles + while not self._vresets.empty(): + vname, success = self._vresets.get() + if vname not in address_map: + raise RuntimeError( + f'Received reset for unknown vehicle: {vname}') -class Agent: - def __init__(self, own_id, opponent_id) -> None: - self.own_id = own_id - self.opponent_id = opponent_id + instr = INSTR_RESET_FAILURE + if success: + instr = INSTR_RESET_SUCCESS - def id(self): - return self.own_id + server.send_instr(address_map[vname], instr) - def obs_to_state(self, observation): - # obs -> state - state = observation + 1 + # This message should only be called on msgs which have actions + def _do_logging(self, msg): + if not self._log: + return + + # Check in whitelist if exists + if self._log_whitelist is not None: + if msg.vname not in self._log_whitelist: + return - return state - - def state_to_act(self, 
state): - act = my_model(state) + # Check if this is a new vehicle + if msg.vname not in self._logs: + path = os.path.join(self._log_path, f"log_{msg.vname}") + self._logs[msg.vname] = ProtoLogger(path, Transition, mode='w') + + if msg._is_transition: + # Write a transition if this is not the first state ever + if msg.vname in self._last_state: + t = Transition() + t.s1.CopyFrom(translate.state_from_dict(self._last_state[msg.vname])) + t.a.CopyFrom(translate.action_from_dict(self._last_act[msg.vname])) + t.s2.CopyFrom(translate.state_from_dict(msg.state)) + + self._logs[msg.vname].write(t) + + # Update the storage for next transition + self._last_state[msg.vname] = msg.state + self._last_act[msg.vname] = msg._response + + def are_present(self, vnames): + ''' + Used to see if a specified list of vehicles has connected to the `EpisodicManager` instance yet. + + See also: [`wait_for()`][mivp_agent.manager.EpisodicManager.wait_for] + + Args: + vnames (iterable): A list / tuple of `str` values to look for + ''' + for vname in vnames: + with self._vname_lock: + if vname not in self._vnames: + return False + return True + + def wait_for(self, vnames, sleep=0.1): + ''' + Used to block until a specified list of vehicles has connect to the `EpisodicManager` instance. + + Args: + vnames (iterable): A list / tuple of `str` values to look for + sleep (float): Amount of time in seconds to sleep for between checks + ''' + while not self.are_present(vnames): + time.sleep(sleep) + + def get_message(self, block=True): + ''' + Used as the primary method for receiving data from `BHV_Agent`. + + **NOTE:** Messages **MUST** be responded to as `BHV_Agent` will not send another update until it has a response to the last. + + Args: + block (bool): A boolean specifying if the method will wait until a message present or return immediately + + Returns: + obj: A instance of [`MissionMessage()`][mivp_agent.manager.MissionMessage] or `None` depending on the blocking behavior + + Example: + ``` + msg = mgr.get_message() + + NAV_X = msg.state['NAV_X'] + NAV_Y = msg.state['NAV_Y'] + + # ... + # Some processing + # ... 
+ + msg.act({ + 'speed': 1.0 + 'course': 180.0 + }) + ``` + ''' + try: + return self._msg_queue.get(block=block) + except Empty: + return None + + def get_vehicle_count(self): + ''' + Returns: + int: The amount of vehicles that have connected to this instance of `EpisodicManager` + ''' + return self._vehicle_count + + def episode_state(self, vname): + ''' + This is used to interrogate the state of a connected vehicle's `pEpisodeManager` + + Args: + vname (str): the vname of the vehicle + + Returns: + str: The state of the `pEpisodeManager` on the vehicle + ''' + with self._ems_lock: + # Should be all strings so no reference odd ness + return self._episode_manager_states[vname] + + def episode_nums(self): + ''' + Returns: + dict: A key, value pair maping vnames to the episode numbers of the `pEpisodeManager` app on that vehicle + ''' + with self._emn_lock: + return self._episode_manager_nums.copy() + + def reset_vehicle(self, vname, success=False): + # Untested + self._vresets.append((vname, success)) + + def run(self, q, attack_actions, retreat_actions): + #q, attack_actions, retreat_actions = load_model(args.model) + + print('Waiting for sim vehicle connections...') + while self.get_vehicle_count() < 1: + time.sleep(0.1) + # --------------------------------------- + # Part 1: Asserting simulation state + + last_state = None + current_action = None + current_action_set = None + console = ModelConsole() + + while True: + # Listen for state + msg = self.get_message() + while False: + print('-------------------------------------------') + print(f"({msg.vname}) {msg.state['HAS_FLAG']}") + print('-------------------------------------------') + msg.request_new() + msg = mgr.get_message() + + console.tick(msg) + + # Detect state transitions + model_state = q.get_state( + msg.state['NAV_X'], + msg.state['NAV_Y'], + msg.state['NODE_REPORTS']['drone_21']['NAV_X'], + msg.state['NODE_REPORTS']['drone_21']['NAV_Y'], + msg.state['HAS_FLAG'] + ) + + # Detect state transition + if model_state != last_state: + current_action = q.get_action(model_state) + last_state = model_state + + # Determine action set + if msg.state['HAS_FLAG']: + current_action_set = retreat_actions + else: + current_action_set = attack_actions + + # Construct instruction for BHV_Agent + action = { + 'speed': current_action_set[current_action]['speed'], + 'course': current_action_set[current_action]['course'] + } + + flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG)) + if flag_dist < 10: + action['posts']= { + 'FLAG_GRAB_REQUEST': f'vname={msg.vname}' + } + + msg.act(action) - return act + def close(self): + if self._thread is not None: + self._stop_signal = True + self._thread.join() + if self._log: + for vehicle in self._logs: + self._logs[vehicle].close() -if __name__ == '__main__': - # Create agents required - agents = [] - wait_for = [] - for i in [1,2,3,4,5]: - agents.append(Agent(f'agent_1{i}', f'drone_2{i}')) - wait_for.append(f'drone_2') - - mgr = EpisodicManager(agents, 1000, wait_for=wait_for) - mgr.run() \ No newline at end of file + def __exit__(self, exc_type, exc_value, traceback): + self.close() From 20408befaf7841f63fb64c2bba04aed0e8d58f89 Mon Sep 17 00:00:00 2001 From: Joseph Doyle Date: Sun, 20 Mar 2022 01:46:28 -0400 Subject: [PATCH 3/9] barebones functional episodic_manager --- examples/QTable/model/new_runner.py | 77 ++- examples/QTable/run.sh | 2 +- examples/QTable/scripts/gui_launch.sh | 8 + .../src/mivp_agent/episodic_manager.py | 446 +++--------------- 4 files changed, 144 
insertions(+), 389 deletions(-) diff --git a/examples/QTable/model/new_runner.py b/examples/QTable/model/new_runner.py index 670bdc6..874b52c 100755 --- a/examples/QTable/model/new_runner.py +++ b/examples/QTable/model/new_runner.py @@ -3,16 +3,83 @@ import time from mivp_agent.episodic_manager import EpisodicManager - +from mivp_agent.util.math import dist +from mivp_agent.util.display import ModelConsole +from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG from constants import DEFAULT_RUN_MODEL from model import load_model +class Agent: + def __init__(self, own_id, opponent_id, model) -> None: + self.own_id = own_id + self.opponent_id = opponent_id + self.q, self.attack_actions, self.retreat_actions = load_model(model) + self.current_action = None + + def id(self): + return self.own_id + + ''' + Idea of the method bellow... + Trainer was previously had to deal with state transitions itself. + msg1 - Message from v1 at time 1 + msg2 - Message from v1 at time 2 + # We can find when transitions by doing this + obs_to_rpr(msg1.state) != obs_to_rpr(msg2.state) + previous_state != obs_to_rpr(blah blah) + ''' + #rename to msg_to_rpr? need full msg and not just msg.state for console.tick + #observation.state doesn't make sense, gotta figure that out + def obs_to_rpr(self, observation): + model_representation = self.q.get_state( + observation.state['NAV_X'], + observation.state['NAV_Y'], + observation.state['NODE_REPORTS'][self.opponent_id]['NAV_X'], + observation.state['NODE_REPORTS'][self.opponent_id]['NAV_Y'], + observation.state['HAS_FLAG'] + ) + + console.tick(observation) + return model_representation + + def rpr_to_act(self, rpr, observation): #why rpr and observation? we talked about this but can't remember + self.current_action = self.q.get_action(rpr) + + # Determine action set + if observation['HAS_FLAG']: + current_action_set = self.retreat_actions + else: + current_action_set = self.attack_actions + + # Construct instruction for BHV_Agent + action = { + 'speed': current_action_set[self.current_action]['speed'], + 'course': current_action_set[self.current_action]['course'] + } + + flag_dist = abs(dist((observation['NAV_X'], observation['NAV_Y']), FIELD_BLUE_FLAG)) + + if flag_dist < 10: + action['posts']= { + 'FLAG_GRAB_REQUEST': f'vname={self.own_id}' + } + + return action + if __name__ == '__main__': + # Create agents required parser = argparse.ArgumentParser() parser.add_argument('--model', default=DEFAULT_RUN_MODEL) - parser.add_argument('--enemy', default='drone_21') args = parser.parse_args() - q, attack_actions, retreat_actions = load_model(args.model) - with EpisodicManager('runner', log=False) as mgr: - mgr.run(q, attack_actions, retreat_actions) + + agents = [] + wait_for = [] + for i in [1, 2, 3]: + agents.append(Agent(f'agent_1{i}', f'drone_2{i}', args.model)) + wait_for.append(f'agent_1{i}') + + console = ModelConsole() #where will this live? + mgr = EpisodicManager(agents, 13, wait_for=wait_for) + mgr.start('runner') + diff --git a/examples/QTable/run.sh b/examples/QTable/run.sh index fa5ebd8..e791702 100755 --- a/examples/QTable/run.sh +++ b/examples/QTable/run.sh @@ -30,7 +30,7 @@ echo "Launching simulation..." # Give time for simulation to startup sleep 5 -# Start trainer with any arguments passed to this script +# Start runner with any arguments passed to this script echo "Launching runner..." 
./model/new_runner.py "$@" diff --git a/examples/QTable/scripts/gui_launch.sh b/examples/QTable/scripts/gui_launch.sh index 52987bf..8179d80 100755 --- a/examples/QTable/scripts/gui_launch.sh +++ b/examples/QTable/scripts/gui_launch.sh @@ -8,9 +8,17 @@ TIME_WARP="4" cd ../mission/heron # Launch a agents ./launch_heron.sh red agent 11 --color=orange $TIME_WARP > /dev/null & + ./launch_heron.sh red agent 12 --color=green $TIME_WARP > /dev/null & + ./launch_heron.sh red agent 13 --color=purple $TIME_WARP > /dev/null & + #./launch_heron.sh red agent 14 --color=gray $TIME_WARP > /dev/null & + #./launch_heron.sh red agent 15 --color=yellow $TIME_WARP > /dev/null & # Launch a blue drone ./launch_heron.sh blue drone 21 --behavior=DEFEND --color=orange $TIME_WARP > /dev/null & + ./launch_heron.sh blue drone 22 --behavior=DEFEND --color=green $TIME_WARP > /dev/null & + ./launch_heron.sh blue drone 23 --behavior=DEFEND --color=purple $TIME_WARP > /dev/null & + #./launch_heron.sh blue drone 24 --behavior=DEFEND --color=gray $LOGGING $TIME_WARP > /dev/null & + #./launch_heron.sh blue drone 25 --behavior=DEFEND --color=yellow $LOGGING $TIME_WARP > /dev/null & cd .. cd shoreside diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py index f561390..cd6f245 100644 --- a/src/python_module/src/mivp_agent/episodic_manager.py +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -1,406 +1,86 @@ -# General -import os -import time -from queue import Queue, Empty -from threading import Thread, Lock -# For core -from mivp_agent.const import KEY_ID, DATA_DIRECTORY -from mivp_agent.messages import MissionMessage, INSTR_SEND_STATE, INSTR_RESET_FAILURE, INSTR_RESET_SUCCESS -from mivp_agent.bridge import ModelBridgeServer -from mivp_agent.util.display import ModelConsole -from mivp_agent.util.math import dist -from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG +from mivp_agent.manager import MissionManager -# For logging -from mivp_agent.log.directory import LogDirectory -from mivp_agent.proto.proto_logger import ProtoLogger -from mivp_agent.proto.mivp_agent_pb2 import Transition -from mivp_agent.proto import translate +class AgentData: + def __init__(self, vname) -> None: + self.vname = vname -class EpisodicManager: ''' - This is the primary method for interfacing with moos-ivp-agent's BHV_Agent - - Examples: - It is recommended to use EpisodicManager with the python context manager - - ``` - from mivp_agent.manager import EpisodicManager - - with EpisodicManager('trainer') as mgr: - mgr.wait_for(['felix', 'evan']) - ... - ``` + Used to recognize if the new episode 'NUM' is a new one. When we have new episodes we must increment the EpisodeManager episode count. ''' + self.last_episode = None - def __init__(self, task, log=True, immediate_transition=True, log_whitelist=None, id_suffix=None, output_dir=None): - ''' - The initializer for EpisodicManager - - Args: - task (str): For organization of saved data type is required to specify what task the EpisodicManager is preforming. For example a `EpisodicManager('trainer')` will log data under `generated_files/trainer/` in the current working directory. - - log (bool): Logging of agent transitions can be disabled by setting this to `False`. - - immediate_transition (bool): By default the the manager will assume that all messages received from BHV_Agents represent a new transition. If set to `False` one must manually tell set `msg.is_transition = True` on any objects returned from `get_message()`. 
This is helpful when you want to control what is considered a "state" in your Markov Decsion Process. - - log_whitelist (list): Setting this parameter will only log some transitions according to their reported `vnames`. - - id_suffix (str): Will be appended to the generated session id. - - output_dir (str): Path to a place to store files. - ''' - self._msg_queue = Queue() - - self._vnames = [] - self._vname_lock = Lock() - self._vehicle_count = 0 - self._episode_manager_states = {} - self._ems_lock = Lock() - self._episode_manager_nums = {} - self._emn_lock = Lock() - - # Dict to hold queues of vnames to reset - self._vresets = Queue() - - self._thread = None - self._stop_signal = False - - if output_dir is None: - output_dir = os.path.join( - os.path.abspath(os.getcwd()), - DATA_DIRECTORY - ) - - self._log_dir = LogDirectory(output_dir) - self._id = self._init_session(id_suffix) - - # Calculate the path to the directories we will be writing to. This will be created when we first use them / return them to the user. - self._model_path = os.path.join( - self._log_dir.models_dir(), - self._id - ) - self._log_path = os.path.join( - self._log_dir.task_dir(task), - self._id - ) - self._model_path = os.path.abspath(self._model_path) - self._log_path = os.path.abspath(self._log_path) - - self._log = log - self._imm_transition = immediate_transition - if self._log: - self._log_whitelist = log_whitelist - # Create data structs needed to log data from each vehicle - self._logs = {} - self._last_state = {} - self._last_act = {} - - # Go ahead and create the log path - os.makedirs(self._log_path) - - def _init_session(self, id_suffix): - # Start the session id with the current timestamp - id = str(round(time.time())) - - # Add suffix if it exists - if id_suffix is not None: - id += f"-{id_suffix}" - - id = self._log_dir.meta.registry.register(id) - - return id - - def model_output_dir(self): - if not os.path.isdir(self._model_path): - os.makedirs(self._model_path) - return self._model_path - - def log_output_dir(self): - assert self._log, "This method should not be used, when logging is disabled" - return self._log_path - - def __enter__(self): - self.start() - return self - - def start(self): - ''' - It is **not recommended** to use this method directly. Instead, consider using this class with the python context manager. This method starts a thread to read from the `ModelBridgeServer`. - - Returns: - bool: False if thread has already been started, True otherwise - ''' - if self._thread is not None: - return False - - self._thread = Thread(target=self._server_thread, daemon=True) - self._thread.start() - - return True - - def _server_thread(self): - live_msg_list = [] - address_map = {} - with ModelBridgeServer() as server: - while not self._stop_signal: - # Accept new clients - addr = server.accept() - if addr is not None: - print(f'Got new connection: {addr}') - server.send_instr(addr, INSTR_SEND_STATE) - - # Listen for messages from vehicles - for addr in server._clients: - msg = server.listen(addr) - - if msg is not None: - with self._vname_lock: - if msg[KEY_ID] not in self._vnames: - print(f'Got new vehicle: {msg[KEY_ID]}') - vname = msg[KEY_ID] - address_map[vname] = addr - self._vnames.append(vname) - self._vehicle_count += 1 - - assert address_map[msg[KEY_ID]] == addr, "Vehicle changed vname. 
This violates routing / logging assumptions made by EpisodicManager" - - m = MissionMessage( - addr, - msg, - is_transition=self._imm_transition - ) - - with self._ems_lock: - self._episode_manager_states[m.vname] = m.episode_state - with self._emn_lock: - if m.episode_report is None: - self._episode_manager_nums[m.vname] = None - else: - self._episode_manager_nums[m.vname] = m.episode_report['NUM'] - - live_msg_list.append(m) - self._msg_queue.put(m) - - # Send responses to vehicle message if there are any - for i, m in enumerate(live_msg_list): - with m._rsp_lock: - if m._response is None: - continue - - # If we got there is response send and remove from list - live_msg_list.remove(m) - server.send_instr(m._addr, m._response) - - # Do logging - self._do_logging(m) - - # Handle reseting of vehicles - while not self._vresets.empty(): - vname, success = self._vresets.get() - - if vname not in address_map: - raise RuntimeError( - f'Received reset for unknown vehicle: {vname}') - - instr = INSTR_RESET_FAILURE - if success: - instr = INSTR_RESET_SUCCESS - - server.send_instr(address_map[vname], instr) - - # This message should only be called on msgs which have actions - def _do_logging(self, msg): - if not self._log: - return - - # Check in whitelist if exists - if self._log_whitelist is not None: - if msg.vname not in self._log_whitelist: - return - - # Check if this is a new vehicle - if msg.vname not in self._logs: - path = os.path.join(self._log_path, f"log_{msg.vname}") - self._logs[msg.vname] = ProtoLogger(path, Transition, mode='w') - - if msg._is_transition: - # Write a transition if this is not the first state ever - if msg.vname in self._last_state: - t = Transition() - t.s1.CopyFrom(translate.state_from_dict(self._last_state[msg.vname])) - t.a.CopyFrom(translate.action_from_dict(self._last_act[msg.vname])) - t.s2.CopyFrom(translate.state_from_dict(msg.state)) - - self._logs[msg.vname].write(t) - - # Update the storage for next transition - self._last_state[msg.vname] = msg.state - self._last_act[msg.vname] = msg._response - - def are_present(self, vnames): - ''' - Used to see if a specified list of vehicles has connected to the `EpisodicManager` instance yet. - - See also: [`wait_for()`][mivp_agent.manager.EpisodicManager.wait_for] - - Args: - vnames (iterable): A list / tuple of `str` values to look for - ''' - for vname in vnames: - with self._vname_lock: - if vname not in self._vnames: - return False - return True - - def wait_for(self, vnames, sleep=0.1): - ''' - Used to block until a specified list of vehicles has connect to the `EpisodicManager` instance. - - Args: - vnames (iterable): A list / tuple of `str` values to look for - sleep (float): Amount of time in seconds to sleep for between checks - ''' - while not self.are_present(vnames): - time.sleep(sleep) - - def get_message(self, block=True): - ''' - Used as the primary method for receiving data from `BHV_Agent`. - - **NOTE:** Messages **MUST** be responded to as `BHV_Agent` will not send another update until it has a response to the last. - - Args: - block (bool): A boolean specifying if the method will wait until a message present or return immediately + ''' + General idea: only call `rpr_to_act` when we reach a new state. Need to store the first action calculated for the state. 
+ ''' + self.current_action = None - Returns: - obj: A instance of [`MissionMessage()`][mivp_agent.manager.MissionMessage] or `None` depending on the blocking behavior + ''' + To identify new states and preform action lookup / calculation from the model. + ''' + self.last_rpr = None - Example: - ``` - msg = mgr.get_message() - NAV_X = msg.state['NAV_X'] - NAV_Y = msg.state['NAV_Y'] +class EpisodicManager: + def __init__(self, agents, episodes, wait_for=None) -> None: + ''' + SETUP AGENTS + ''' + self.agents = agents + self.agent_data = {} - # ... - # Some processing - # ... + # Combine the agent's id and any additional wait fors + self.wait_for = wait_for - msg.act({ - 'speed': 1.0 - 'course': 180.0 - }) - ``` - ''' - try: - return self._msg_queue.get(block=block) - except Empty: - return None + if self.wait_for is None: + self.wait_for = [] + # Setup things specific to each agent - def get_vehicle_count(self): - ''' - Returns: - int: The amount of vehicles that have connected to this instance of `EpisodicManager` - ''' - return self._vehicle_count + for a in self.agents: + # Make sure they are waited for + self.wait_for.append(a.id()) + # Make a agent state for them + self.agent_data[a.id()] = AgentData(a.id()) - def episode_state(self, vname): - ''' - This is used to interrogate the state of a connected vehicle's `pEpisodeManager` + ''' + SETUP EPISODE TRACKING + ''' + self.episodes = episodes + self.current_episode = 0 - Args: - vname (str): the vname of the vehicle + def start(self, task, log=True): + with MissionManager(task, log=log) as mgr: + mgr.wait_for(self.wait_for) - Returns: - str: The state of the `pEpisodeManager` on the vehicle - ''' - with self._ems_lock: - # Should be all strings so no reference odd ness - return self._episode_manager_states[vname] + while self.current_episode < self.episodes: + msg = mgr.get_message() - def episode_nums(self): - ''' - Returns: - dict: A key, value pair maping vnames to the episode numbers of the `pEpisodeManager` app on that vehicle - ''' - with self._emn_lock: - return self._episode_manager_nums.copy() + # Find agent in list... 
+ for a in self.agents: + if msg.vname == a.id(): + data = self.agent_data[a.id()] - def reset_vehicle(self, vname, success=False): - # Untested - self._vresets.append((vname, success)) + # Probably always + # rpr = a.obs_to_rpr(msg.state) #state nomenclature still lurking + rpr = a.obs_to_rpr(msg) # full msg for console.tick - def run(self, q, attack_actions, retreat_actions): - #q, attack_actions, retreat_actions = load_model(args.model) + if data.last_rpr != rpr: + msg.mark_transition() + # still need state here bc rpr_to_act expects obs + data.current_action = a.rpr_to_act(rpr, msg.state) - print('Waiting for sim vehicle connections...') - while self.get_vehicle_count() < 1: - time.sleep(0.1) - # --------------------------------------- - # Part 1: Asserting simulation state - - last_state = None - current_action = None - current_action_set = None - console = ModelConsole() - - while True: - # Listen for state - msg = self.get_message() - while False: - print('-------------------------------------------') - print(f"({msg.vname}) {msg.state['HAS_FLAG']}") - print('-------------------------------------------') - msg.request_new() - msg = mgr.get_message() - - console.tick(msg) - - # Detect state transitions - model_state = q.get_state( - msg.state['NAV_X'], - msg.state['NAV_Y'], - msg.state['NODE_REPORTS']['drone_21']['NAV_X'], - msg.state['NODE_REPORTS']['drone_21']['NAV_Y'], - msg.state['HAS_FLAG'] - ) - - # Detect state transition - if model_state != last_state: - current_action = q.get_action(model_state) - last_state = model_state - - # Determine action set - if msg.state['HAS_FLAG']: - current_action_set = retreat_actions - else: - current_action_set = attack_actions - - # Construct instruction for BHV_Agent - action = { - 'speed': current_action_set[current_action]['speed'], - 'course': current_action_set[current_action]['course'] - } - - flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG)) - if flag_dist < 10: - action['posts']= { - 'FLAG_GRAB_REQUEST': f'vname={msg.vname}' - } - - msg.act(action) + # Update episode count if applicable + if data.last_episode != msg.episode_report['NUM']: + data.last_episode = msg.episode_report['NUM'] - def close(self): - if self._thread is not None: - self._stop_signal = True - self._thread.join() - if self._log: - for vehicle in self._logs: - self._logs[vehicle].close() + ################################################ + # Importantly, update the global episode count # + self.current_episode += 1 + # track data + data.last_rpr = rpr - def __exit__(self, exc_type, exc_value, traceback): - self.close() + ################################################ + # Importantly, actually do shit # + msg.act(data.current_action) From 612b1d10c49bcf0c7df30c2a420158a7223e7868 Mon Sep 17 00:00:00 2001 From: Joseph Doyle Date: Sun, 20 Mar 2022 02:09:13 -0400 Subject: [PATCH 4/9] removed console.tick(msg) for consistency --- examples/QTable/model/new_runner.py | 29 ++++++------------- .../src/mivp_agent/episodic_manager.py | 4 +-- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/QTable/model/new_runner.py b/examples/QTable/model/new_runner.py index 874b52c..521b34e 100755 --- a/examples/QTable/model/new_runner.py +++ b/examples/QTable/model/new_runner.py @@ -4,7 +4,7 @@ from mivp_agent.episodic_manager import EpisodicManager from mivp_agent.util.math import dist -from mivp_agent.util.display import ModelConsole +#from mivp_agent.util.display import ModelConsole from mivp_agent.aquaticus.const import 
FIELD_BLUE_FLAG from constants import DEFAULT_RUN_MODEL from model import load_model @@ -19,27 +19,16 @@ def __init__(self, own_id, opponent_id, model) -> None: def id(self): return self.own_id - ''' - Idea of the method bellow... - Trainer was previously had to deal with state transitions itself. - msg1 - Message from v1 at time 1 - msg2 - Message from v1 at time 2 - # We can find when transitions by doing this - obs_to_rpr(msg1.state) != obs_to_rpr(msg2.state) - previous_state != obs_to_rpr(blah blah) - ''' - #rename to msg_to_rpr? need full msg and not just msg.state for console.tick - #observation.state doesn't make sense, gotta figure that out def obs_to_rpr(self, observation): model_representation = self.q.get_state( - observation.state['NAV_X'], - observation.state['NAV_Y'], - observation.state['NODE_REPORTS'][self.opponent_id]['NAV_X'], - observation.state['NODE_REPORTS'][self.opponent_id]['NAV_Y'], - observation.state['HAS_FLAG'] + observation['NAV_X'], + observation['NAV_Y'], + observation['NODE_REPORTS'][self.opponent_id]['NAV_X'], + observation['NODE_REPORTS'][self.opponent_id]['NAV_Y'], + observation['HAS_FLAG'] ) + #console.tick(observation) #needs full msg, obervation is an msg.state - console.tick(observation) return model_representation def rpr_to_act(self, rpr, observation): #why rpr and observation? we talked about this but can't remember @@ -79,7 +68,7 @@ def rpr_to_act(self, rpr, observation): #why rpr and observation? we talked abou agents.append(Agent(f'agent_1{i}', f'drone_2{i}', args.model)) wait_for.append(f'agent_1{i}') - console = ModelConsole() #where will this live? - mgr = EpisodicManager(agents, 13, wait_for=wait_for) + #console = ModelConsole() #where will this live? Do we care? Needs a full msg instead of msg.state + mgr = EpisodicManager(agents, 13, wait_for=wait_for) #13 for 10 full episodes... 
not ideal mgr.start('runner') diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py index cd6f245..e182d54 100644 --- a/src/python_module/src/mivp_agent/episodic_manager.py +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -62,8 +62,8 @@ def start(self, task, log=True): data = self.agent_data[a.id()] # Probably always - # rpr = a.obs_to_rpr(msg.state) #state nomenclature still lurking - rpr = a.obs_to_rpr(msg) # full msg for console.tick + rpr = a.obs_to_rpr(msg.state) #state nomenclature still lurking + #rpr = a.obs_to_rpr(msg) # full msg for console.tick if data.last_rpr != rpr: msg.mark_transition() From f18867685892d73864e691e87f8a462235c5a46a Mon Sep 17 00:00:00 2001 From: "carter.fendley" Date: Sat, 26 Mar 2022 17:11:13 -0400 Subject: [PATCH 5/9] Add simple test for episodic manager constructor --- src/python_module/test/test_all.py | 2 ++ .../test/test_episodic_manager.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 src/python_module/test/test_episodic_manager.py diff --git a/src/python_module/test/test_all.py b/src/python_module/test/test_all.py index 00c087a..90f3313 100755 --- a/src/python_module/test/test_all.py +++ b/src/python_module/test/test_all.py @@ -15,6 +15,7 @@ import test_bridge import test_log import test_manager +import test_episodic_manager import test_data_structures import test_proto import test_consumer @@ -36,6 +37,7 @@ suite.addTest(unittest.makeSuite(test_consumer.TestConsumer)) suite.addTest(unittest.makeSuite(test_manager.TestManagerCore)) suite.addTest(unittest.makeSuite(test_manager.TestManagerLogger)) + suite.addTest(unittest.makeSuite(test_episodic_manager.TestEpisodicManager)) suite.addTest(unittest.makeSuite(test_data_structures.TestLimitedHistory)) suite.addTest(unittest.makeSuite(test_proto.TestLogger)) diff --git a/src/python_module/test/test_episodic_manager.py b/src/python_module/test/test_episodic_manager.py new file mode 100644 index 0000000..40c9de6 --- /dev/null +++ b/src/python_module/test/test_episodic_manager.py @@ -0,0 +1,28 @@ +from mivp_agent.episodic_manager import EpisodicManager + +import unittest + +class FakeAgent: + def __init__(self, id): + self._id = id + + def id(self): + return self._id + +class TestEpisodicManager(unittest.TestCase): + def test_constructor(self): + # Test failure with no args + self.assertRaises(TypeError, EpisodicManager) + + agents = [] + agents.append(FakeAgent('joe')) + agents.append(FakeAgent('carter')) + + # Test failure with with no episodes + self.assertRaises(TypeError, EpisodicManager, agents) + + # Test success + EpisodicManager(agents, 100) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 2fb423fe0f1d215728631b3353c4e3d21a025020 Mon Sep 17 00:00:00 2001 From: "carter.fendley" Date: Sun, 27 Mar 2022 15:19:41 -0400 Subject: [PATCH 6/9] Add stopping mechanism to EpisodicManager --- .../src/mivp_agent/episodic_manager.py | 46 +++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py index e182d54..cfd8e06 100644 --- a/src/python_module/src/mivp_agent/episodic_manager.py +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -1,3 +1,5 @@ +import time +from threading import Event, Lock from mivp_agent.manager import MissionManager @@ -17,7 +19,7 @@ def __init__(self, vname) -> None: self.current_action = None ''' - To identify 
new states and preform action lookup / calculation from the model. + To identify new states and perform action lookup / calculation from the model. ''' self.last_rpr = None @@ -49,12 +51,46 @@ def __init__(self, agents, episodes, wait_for=None) -> None: self.episodes = episodes self.current_episode = 0 + # Control signals + self._run_lock = Lock() + self._stop_signal = Event() + + ''' + Perform non blocking acquire on the run lock to test if the lock is acquired by the run method. The lock will be released right after acquisition. + ''' + def is_running(self): + if self._run_lock.acquire(False): + self._run_lock.release() + return False + return True + + def _should_stop(self): + return self.current_episode >= self.episodes or \ + self._stop_signal.is_set() + + def stop(self): + if not self.is_running(): + RuntimeError('Stop called before start.') + self._stop_signal.set() + def start(self, task, log=True): + if not self._run_lock.acquire(False): + raise RuntimeError('Start should only be called once.') + with MissionManager(task, log=log) as mgr: - mgr.wait_for(self.wait_for) + # Below is similar to `mgr.wait_for(...)` but respects out stop signal + while not mgr.are_present(self.wait_for) and \ + not self._should_stop(): + time.sleep(0.1) + + while not self._should_stop(): + # Non blocking so `stop()` method will work immediately + msg = mgr.get_message(block=False) - while self.current_episode < self.episodes: - msg = mgr.get_message() + # If we didn't get a message sleep and then loop + if msg is None: + time.sleep(0.1) + continue # Find agent in list... for a in self.agents: @@ -84,3 +120,5 @@ def start(self, task, log=True): ################################################ # Importantly, actually do shit # msg.act(data.current_action) + + self._run_lock.release() From 67dd52b48cc6db380ab896288c00fd0911e31a72 Mon Sep 17 00:00:00 2001 From: "carter.fendley" Date: Sun, 27 Mar 2022 15:20:04 -0400 Subject: [PATCH 7/9] Allow None to be returned from rpr_to_act --- src/python_module/src/mivp_agent/episodic_manager.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py index cfd8e06..db33bf4 100644 --- a/src/python_module/src/mivp_agent/episodic_manager.py +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -117,8 +117,10 @@ def start(self, task, log=True): # track data data.last_rpr = rpr - ################################################ - # Importantly, actually do shit # - msg.act(data.current_action) + # If we have an action send that, otherwise mark message as handled and request new one. 
+ if data.current_action is None: + msg.request_new() + else: + msg.act(data.current_action) self._run_lock.release() From ef74f9e60da53a35900f221f94031ed429219bd1 Mon Sep 17 00:00:00 2001 From: "carter.fendley" Date: Sun, 27 Mar 2022 15:39:26 -0400 Subject: [PATCH 8/9] Add more complex testing and some fixtures --- .../test/test_episodic_manager.py | 100 +++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/src/python_module/test/test_episodic_manager.py b/src/python_module/test/test_episodic_manager.py index 40c9de6..03f0864 100644 --- a/src/python_module/test/test_episodic_manager.py +++ b/src/python_module/test/test_episodic_manager.py @@ -1,15 +1,78 @@ +import unittest +import time +from threading import Thread + +from mivp_agent.bridge import ModelBridgeClient from mivp_agent.episodic_manager import EpisodicManager +from mivp_agent.const import KEY_ID, KEY_EPISODE_MGR_REPORT, KEY_EPISODE_MGR_STATE -import unittest +DUMMY_STATE = { + KEY_ID: 'felix', + 'MOOS_TIME': 16923.012, + 'NAV_X': 98.0, + 'NAV_Y': 40.0, + 'NAV_HEADING': 180, + KEY_EPISODE_MGR_REPORT: 'NUM=0,DURATION=60.57,SUCCESS=false,WILL_PAUSE=false', + KEY_EPISODE_MGR_STATE: 'PAUSED' +} class FakeAgent: def __init__(self, id): self._id = id + self.states = [] def id(self): return self._id + + def obs_to_rpr(self, state): + self.states.append(state) + + def rpr_to_act(self, rpr, state): + pass class TestEpisodicManager(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.agents = { + 'misha': FakeAgent('misha'), + 'mike': FakeAgent('mike') + } + + # For each agent create a client + self.clients = {} + for name in self.agents: + self.clients[name] = ModelBridgeClient() + + # Get list of agent names for episodic manager + agnt_lst = [agent for _, agent in self.agents.items()] + self.mgr = EpisodicManager(agnt_lst, 100) + + # Start manager in another thread + self.mgr_thread = Thread(target=self.mgr.start, args=('tester',)) + self.mgr_thread.start() + + # Connect clients to server started by the manager + all_connected = False + while not all_connected: + all_connected = True + + for c in self.clients: + if not self.clients[c].is_connected(): + self.clients[c].connect() + all_connected = False + + def tearDown(self) -> None: + super().tearDown() + if self.mgr.is_running(): + self.mgr.stop() + + self.mgr_thread.join() + + time.sleep(0.1) + + for c in self.clients: + self.clients[c].close() + def test_constructor(self): # Test failure with no args self.assertRaises(TypeError, EpisodicManager) @@ -24,5 +87,40 @@ def test_constructor(self): # Test success EpisodicManager(agents, 100) + def test_stop(self): + # Sanity check + self.assertTrue(self.mgr.is_running()) + + # Give the stop signal + self.mgr.stop() + + # Allow time for release then assert + time.sleep(0.2) + self.assertFalse(self.mgr.is_running()) + + def test_routing(self): + # Setup two states to be sent + misha_state = DUMMY_STATE.copy() + misha_state[KEY_ID] = 'misha' + misha_state['NAV_X'] = 1.0 + mike_state = DUMMY_STATE.copy() + mike_state[KEY_ID] = 'mike' + mike_state['NAV_X'] = 0.0 + + # Send states + self.clients['misha'].send_state(misha_state) + self.clients['mike'].send_state(mike_state) + + # Allow time for propagation + time.sleep(0.2) + + # Check that both received one state + self.assertEqual(len(self.agents['misha'].states), 1) + self.assertEqual(len(self.agents['mike'].states), 1) + + # Check that they received the proper state + self.assertEqual(self.agents['misha'].states[0]['NAV_X'], 1.0) + 
self.assertEqual(self.agents['mike'].states[0]['NAV_X'], 0.0) + if __name__ == '__main__': unittest.main() \ No newline at end of file From a74934177908dcd892b3f3fb1a9cae84cf6bcb34 Mon Sep 17 00:00:00 2001 From: Joe Doyle Date: Fri, 15 Apr 2022 10:42:27 -0400 Subject: [PATCH 9/9] minimal episodic manager functioning with runner and trainer --- examples/QTable/model/runner.py | 120 +++--- examples/QTable/model/trainer.py | 369 +++++++----------- .../src/mivp_agent/episodic_manager.py | 73 ++-- 3 files changed, 214 insertions(+), 348 deletions(-) diff --git a/examples/QTable/model/runner.py b/examples/QTable/model/runner.py index d5fe1b6..152c089 100755 --- a/examples/QTable/model/runner.py +++ b/examples/QTable/model/runner.py @@ -2,79 +2,73 @@ import argparse import time -from mivp_agent.manager import MissionManager +from mivp_agent.episodic_manager import EpisodicManager from mivp_agent.util.math import dist -from mivp_agent.util.display import ModelConsole +#from mivp_agent.util.display import ModelConsole from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG - from constants import DEFAULT_RUN_MODEL from model import load_model +class Agent: + def __init__(self, own_id, opponent_id, model) -> None: + self.own_id = own_id + self.opponent_id = opponent_id + self.q, self.attack_actions, self.retreat_actions = load_model(model) + self.current_action = None + + def id(self): + return self.own_id + + + def obs_to_rpr(self, observation): + model_representation = self.q.get_state( + observation['NAV_X'], + observation['NAV_Y'], + observation['NODE_REPORTS'][self.opponent_id]['NAV_X'], + observation['NODE_REPORTS'][self.opponent_id]['NAV_Y'], + observation['HAS_FLAG'] + ) + #console.tick(observation) #needs full msg, obervation is an msg.state + + return model_representation + + def rpr_to_act(self, rpr, observation): + self.current_action = self.q.get_action(rpr) + + # Determine action set + if observation['HAS_FLAG']: + current_action_set = self.retreat_actions + else: + current_action_set = self.attack_actions -def run(args): - q, attack_actions, retreat_actions = load_model(args.model) - - with MissionManager('runner', log=False) as mgr: - print('Waiting for sim vehicle connections...') - while mgr.get_vehicle_count() < 1: - time.sleep(0.1) - # --------------------------------------- - # Part 1: Asserting simulation state - - last_state = None - current_action = None - current_action_set = None - console = ModelConsole() - - while True: - # Listen for state - msg = mgr.get_message() - while False: - print('-------------------------------------------') - print(f"({msg.vname}) {msg.state['HAS_FLAG']}") - print('-------------------------------------------') - msg.request_new() - msg = mgr.get_message() - - console.tick(msg) - - # Detect state transitions - model_state = q.get_state( - msg.state['NAV_X'], - msg.state['NAV_Y'], - msg.state['NODE_REPORTS'][args.enemy]['NAV_X'], - msg.state['NODE_REPORTS'][args.enemy]['NAV_Y'], - msg.state['HAS_FLAG'] - ) - - # Detect state transition - if model_state != last_state: - current_action = q.get_action(model_state) - last_state = model_state - - # Determine action set - if msg.state['HAS_FLAG']: - current_action_set = retreat_actions - else: - current_action_set = attack_actions - - # Construct instruction for BHV_Agent - action = { - 'speed': current_action_set[current_action]['speed'], - 'course': current_action_set[current_action]['course'] + # Construct instruction for BHV_Agent + action = { + 'speed': 
current_action_set[self.current_action]['speed'], + 'course': current_action_set[self.current_action]['course'] + } + + flag_dist = abs(dist((observation['NAV_X'], observation['NAV_Y']), FIELD_BLUE_FLAG)) + + if flag_dist < 10: + action['posts']= { + 'FLAG_GRAB_REQUEST': f'vname={self.own_id}' } - flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG)) - if flag_dist < 10: - action['posts']= { - 'FLAG_GRAB_REQUEST': f'vname={msg.vname}' - } - - msg.act(action) + return action + if __name__ == '__main__': + # Create agents required parser = argparse.ArgumentParser() parser.add_argument('--model', default=DEFAULT_RUN_MODEL) - parser.add_argument('--enemy', default='drone_21') args = parser.parse_args() - run(args) \ No newline at end of file + + agents = [] + wait_for = [] + for i in [1, 2, 3]: + agents.append(Agent(f'agent_1{i}', f'drone_2{i}', args.model)) + + #console = ModelConsole() #where will this live? Do we care? Needs a full msg instead of msg.state + mgr = EpisodicManager(agents, 13, wait_for=wait_for) #13 for 10 full episodes... not ideal + mgr.start('runner') + diff --git a/examples/QTable/model/trainer.py b/examples/QTable/model/trainer.py index 551aa81..ca398fc 100755 --- a/examples/QTable/model/trainer.py +++ b/examples/QTable/model/trainer.py @@ -1,16 +1,9 @@ #!/usr/bin/env python3 -import os -import time -import wandb import argparse -from tqdm import tqdm - -from mivp_agent.manager import MissionManager -from mivp_agent.util.display import ModelConsole +from mivp_agent.episodic_manager import EpisodicManager from mivp_agent.util.math import dist -from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG - from wandb_key import WANDB_KEY +from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG from constants import LEARNING_RATE, DISCOUNT, EPISODES from constants import FIELD_RESOLUTION from constants import EPSILON_START, EPSILON_DECAY_START, EPSILON_DECAY_AMT, EPSILON_DECAY_END @@ -24,27 +17,22 @@ 'agent_11': 'drone_21', 'agent_12': 'drone_22', 'agent_13': 'drone_23', - 'agent_14': 'drone_24', - 'agent_15': 'drone_25' + #'agent_14': 'drone_24', + #'agent_15': 'drone_25' } EXPECTED_VEHICLES = [key for key in VEHICLE_PAIRING] -class AgentData: - ''' - Used to encapsulate the data needed to run each indivdual - agent, track state / episode transitions, and output useful - information and statistics. 
- ''' - def __init__(self, vname, enemy): - self.vname = vname - self.enemy = enemy - - # For running of simulation - self.agent_episode_count = 0 - self.last_episode_num = None # Episode transitions - self.last_state = None # State transitions - self.had_flag = False # Capturing grab transitions for rewarding - self.current_action = None + +class Agent: + def __init__(self, own_id, opponent_id, q, config) -> None: + self.own_id = own_id + self.opponent_id = opponent_id + self.q = q + self.attack_actions = config['attack_actions'] + self.retreat_actions = config['retreat_actions'] + self.current_action = None + self.config = config + self.last_rpr = None # For debugging / output self.min_dist = None @@ -52,225 +40,126 @@ def __init__(self, vname, enemy): self.last_MOOS_time = None self.MOOS_deltas = [] self.grab_time = None - - def new_episode(self, last_num): - self.last_episode_num = last_num - self.agent_episode_count += 1 - self.last_state = None + def id(self): + ''' + Required hook for EpisodicManager to access agent ID + ''' + return self.own_id + + def reset_tracking_vars(self): + ''' + Reset per-episode tracking vars for this agent + ''' + #Functional + self.last_rpr = None self.had_flag = False self.current_action = None + #Debugging and I/O self.min_dist = None self.episode_reward = 0 self.last_MOOS_time = None self.MOOS_deltas.clear() self.grab_time = None - -def train(args, config, run_name): - agents = {} - with MissionManager('trainer', log=True, immediate_transition=False, id_suffix=run_name) as mgr: - # Create a directory for the model to save - model_save_dir = mgr.model_output_dir() - - # Setup model - q = QLearn( - lr=config['lr'], - gamma=config['gamma'], - action_space_size=config['action_space_size'], - field_res=config['field_res'], - verbose=args.debug, - save_dir=model_save_dir + def obs_to_rpr(self, observation): + ''' + Hook for EpisodicManager to convert an observation to python representation + Called on each new oberservation, used to detect state transitions + ''' + model_representation = self.q.get_state( + observation['NAV_X'], + observation['NAV_Y'], + observation['NODE_REPORTS'][self.opponent_id]['NAV_X'], + observation['NODE_REPORTS'][self.opponent_id]['NAV_Y'], + observation['HAS_FLAG'] ) - print('Waiting for sim vehicle connection...') - mgr.wait_for(EXPECTED_VEHICLES) - - # Construct agent data object with enemy specified - agents = {} - for a in VEHICLE_PAIRING: - agents[a] = AgentData(a, VEHICLE_PAIRING[a]) - - # While all vehicle's pEpisodeManager are not PAUSED - print('Waiting for pEpisodeManager to enter PAUSED state...') - while not all(mgr.episode_state(vname) == 'PAUSED' for vname in agents): - msg = mgr.get_message() - if mgr.episode_state(msg.vname) != 'PAUSED': - # Instruct them to stop and pause - msg.stop() - else: - # O/w just keep up to date on their state - msg.request_new() - - # --------------------------------------- - # Part 2: Global state initalization - episode_count = 0 - epsilon = config['epsilon_start'] - progress_bar = tqdm(total=config['episodes'], desc='Training') - # Debugging - total_sim_time = 0 - - # Initalize the epidode numbers from pEpisodeManager - e_nums = mgr.episode_nums() - for a in e_nums: - agents[a].last_episode_num = e_nums[a] - - print('Running training...') - while episode_count < config['episodes']: - ''' - This loop handles an indivual msg from an indivual agent - data from each agent is used during training - ''' - # Listen for an agent's message & get it's data - msg = mgr.get_message() - # If 
pEpisodeManager is paused, start and continue to next agent - if msg.episode_state == 'PAUSED': - msg.mark_transition() # Initial state should be a transition - msg.start() - continue - agent_data = agents[msg.vname] - # Update debugging MOOS time - if agent_data.last_MOOS_time is not None: - agent_data.MOOS_deltas.append(msg.state['MOOS_TIME']-agent_data.last_MOOS_time) - agent_data.last_MOOS_time = msg.state['MOOS_TIME'] - - ''' - Part 1: Translate MOOS state to model's state representation - ''' - #print(msg.state['HAS_FLAG']) - model_state = q.get_state( - msg.state['NAV_X'], - msg.state['NAV_Y'], - msg.state['NODE_REPORTS'][agent_data.enemy]['NAV_X'], - msg.state['NODE_REPORTS'][agent_data.enemy]['NAV_Y'], - msg.state['HAS_FLAG'] + return model_representation + + def rpr_to_act(self, rpr, observation, em_report): + ''' + Hook for EpisodicManager to get next action based on current state representation + Called on each new state representation + Most functional code goes here + ''' + # Update previous state action pair + reward = self.config['reward_step'] + if observation['HAS_FLAG'] and not self.had_flag: + reward = self.config['reward_grab'] #Should this be += to also account for the step? + self.had_flag = True + # Capture time for debugging (time after grab) + self.grab_time = observation['MOOS_TIME'] + + if self.last_rpr is not None: + self.q.update_table( + self.last_rpr, + self.current_action, + reward, + rpr ) + self.episode_reward += self.config['reward_step'] #shouldn't this just be += reward? - # Detect discrete state transitions - if model_state != agent_data.last_state: - ''' - Part 2: Handle the ending of episodes - ''' - # Mark this state as a transition to record it to logs - msg.mark_transition() - - if msg.episode_report is None: - assert agent_data.agent_episode_count == 0 - elif msg.episode_report['DURATION'] < 2: - # Bad episode, don't use data - # Reset episode data and including 'last_episode_num' - agent_data.new_episode(msg.episode_report['NUM']) - elif msg.episode_report['NUM'] != agent_data.last_episode_num: - # Calculate reward based on pEpisodeManager's report - reward = config['reward_failure'] - if msg.episode_report['SUCCESS']: - reward = config['reward_capture'] - - # Apply this reward to the QTable - q.set_qvalue( - agent_data.last_state, - agent_data.current_action, - reward - ) - - # Update the total sim time - total_sim_time += msg.episode_report['DURATION'] - - # Construct report - report = { - 'episode_count': episode_count, - 'reward': agent_data.episode_reward+reward, - 'epsilon': round(epsilon, 3), - 'duration': round(msg.episode_report['DURATION'],2), - 'success': msg.episode_report['SUCCESS'], - 'min_dist': round(agent_data.min_dist, 2), - 'had_flag': agent_data.had_flag, - 'sim_time': round(total_sim_time, 2), - 'sim_days': round(total_sim_time / 86400, 2) - } - - if len(agent_data.MOOS_deltas) != 0: - report['avg_delta'] = round(sum(agent_data.MOOS_deltas)/len(agent_data.MOOS_deltas),2) - else: - report['avg_delta'] = 0.0 - - if agent_data.grab_time is not None: - report['post_grab_duration'] = round(msg.state['MOOS_TIME'] - agent_data.grab_time, 2) - else: - report['post_grab_duration'] = 0.0 - - # Log the report - if not args.no_wandb: - wandb.log(report) - tqdm.write(f'[{msg.vname}] ', end='') - tqdm.write(', '.join([f'{k}: {report[k]}' for k in report])) - - # Decay epsilon - if config['epsilon_decay_end'] >= episode_count >= config['epsilon_decay_start']: - epsilon -= config['epsilon_decay_amt'] - - # Reset episode data and including 
'last_episode_num' - agent_data.new_episode(msg.episode_report['NUM']) - # Update global episode count - episode_count += 1 - progress_bar.update(1) - - # Save model if applicable - if episode_count % SAVE_EVERY == 0: - q.save( - config['attack_actions'], - config['retreat_actions'], - name=f'episode_{episode_count}' - ) + #################### NO #################### + global epsilon + ############# NO NO NO NO NO NO ############ + + self.current_action = q.get_action(rpr, e=epsilon) + self.last_rpr = rpr + + actions = self.attack_actions + if observation['HAS_FLAG']: # Use retreat actions if already grabbed and... retreating + actions = self.retreat_actions + action = actions[self.current_action].copy() # Copy out of reference paranoia + + flag_dist = abs(dist((observation['NAV_X'], observation['NAV_Y']), FIELD_BLUE_FLAG)) + + if flag_dist < 10: + action['posts']= { + 'FLAG_GRAB_REQUEST': f'vname={self.own_id}' + } + + return action + + def episode_end(self, rpr, observation, em_report): + ''' + Hook for EpisodicManager, called at the end of each episode + End rewards, model saving, etc. go here + ''' + reward = self.config['reward_failure'] + if em_report['success']: + reward = self.config['reward_capture'] + + # Apply this reward to the QTable + if self.last_rpr is not None: + self.q.set_qvalue( + self.last_rpr, + self.current_action, + reward + ) + else: + print("last_rpr was None") - ''' - Part 3: Handle updating actions / qtable in new states - ''' + completed_episodes = em_report['completed_episodes'] - # Update previous state action pair - reward = config['reward_step'] - if msg.state['HAS_FLAG'] and not agent_data.had_flag: - reward = config['reward_grab'] - agent_data.had_flag = True - # Capture time for debugging (time after grab) - agent_data.grab_time = msg.state['MOOS_TIME'] + #################### NO #################### + global epsilon + ############# NO NO NO NO NO NO ############ - if agent_data.last_state is not None: - q.update_table( - agent_data.last_state, - agent_data.current_action, - reward, - model_state - ) - agent_data.episode_reward += config['reward_step'] - - # Update tracking data - agent_data.current_action = q.get_action(model_state, e=epsilon) - agent_data.last_state = model_state - - ''' - Part 4: Even when agent is not in new state, keep preforming - the action that was calcualted on the when the state transitioned - ''' - actions = config['attack_actions'] - if msg.state['HAS_FLAG']: # Use retreat actions if already grabbed and... 
retreating - actions = config['retreat_actions'] - action = actions[agent_data.current_action].copy() # Copy out of reference paranoia - - flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG)) - # If this agent can grab the flag, do so - if flag_dist < 10: - action['posts'] = { - 'FLAG_GRAB_REQUEST': f'vname={msg.vname}' - } + # Decay epsilon + if self.config['epsilon_decay_end'] >= completed_episodes >= self.config['epsilon_decay_start']: + epsilon -= self.config['epsilon_decay_amt'] - # Send action - msg.act(action) + self.reset_tracking_vars() - # Debugging stuff - if agent_data.min_dist is None or agent_data.min_dist > flag_dist: - agent_data.min_dist = flag_dist + # Save model if applicable + if completed_episodes % config['save_every'] == 0: + q.save( + config['attack_actions'], + config['retreat_actions'], + name=f'episode_{completed_episodes}' + ) if __name__ == '__main__': @@ -301,12 +190,24 @@ def train(args, config, run_name): 'reward_capture': REWARD_CAPTURE, 'reward_failure': REWARD_FAILURE, 'reward_step': REWARD_STEP, + 'save_every': SAVE_EVERY, } - if args.no_wandb: - train(args, config, None) - else: - wandb.login(key=WANDB_KEY) - with wandb.init(project='mivp_agent_qtable', config=config): - config = wandb.config - train(args, config, f'{wandb.run.name}') \ No newline at end of file + # Setup model + q = QLearn( + lr=config['lr'], + gamma=config['gamma'], + action_space_size=config['action_space_size'], + field_res=config['field_res'], + verbose=args.debug, + save_dir=SAVE_DIR + ) + epsilon = config['epsilon_start'] + + agents = [] + wait_for = [] + for key in VEHICLE_PAIRING: + agents.append(Agent(key, VEHICLE_PAIRING[key], q, config)) + + mgr = EpisodicManager(agents, EPISODES, wait_for=wait_for) + mgr.start('trainer') \ No newline at end of file diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py index db33bf4..19dcd47 100644 --- a/src/python_module/src/mivp_agent/episodic_manager.py +++ b/src/python_module/src/mivp_agent/episodic_manager.py @@ -1,5 +1,3 @@ -import time -from threading import Event, Lock from mivp_agent.manager import MissionManager @@ -49,48 +47,21 @@ def __init__(self, agents, episodes, wait_for=None) -> None: SETUP EPISODE TRACKING ''' self.episodes = episodes - self.current_episode = 0 - - # Control signals - self._run_lock = Lock() - self._stop_signal = Event() - - ''' - Perform non blocking acquire on the run lock to test if the lock is acquired by the run method. The lock will be released right after acquisition. 
- ''' - def is_running(self): - if self._run_lock.acquire(False): - self._run_lock.release() - return False - return True - - def _should_stop(self): - return self.current_episode >= self.episodes or \ - self._stop_signal.is_set() - - def stop(self): - if not self.is_running(): - RuntimeError('Stop called before start.') - self._stop_signal.set() + self.completed_episode = 0 - def start(self, task, log=True): - if not self._run_lock.acquire(False): - raise RuntimeError('Start should only be called once.') + def _build_report(self): + report = { + 'completed_episodes': self.completed_episode, + } - with MissionManager(task, log=log) as mgr: - # Below is similar to `mgr.wait_for(...)` but respects out stop signal - while not mgr.are_present(self.wait_for) and \ - not self._should_stop(): - time.sleep(0.1) + return report - while not self._should_stop(): - # Non blocking so `stop()` method will work immediately - msg = mgr.get_message(block=False) + def start(self, task, log=True): + with MissionManager(task, log=log) as mgr: + mgr.wait_for(self.wait_for) - # If we didn't get a message sleep and then loop - if msg is None: - time.sleep(0.1) - continue + while self.completed_episode < self.episodes: + msg = mgr.get_message() # Find agent in list... for a in self.agents: @@ -103,24 +74,24 @@ def start(self, task, log=True): if data.last_rpr != rpr: msg.mark_transition() + em_report = self._build_report() # still need state here bc rpr_to_act expects obs - data.current_action = a.rpr_to_act(rpr, msg.state) + data.current_action = a.rpr_to_act(rpr, msg.state, em_report) # Update episode count if applicable if data.last_episode != msg.episode_report['NUM']: + if data.last_episode is not None: + # update the global episode count # if not first + self.completed_episode += 1 data.last_episode = msg.episode_report['NUM'] + em_report = self._build_report() + em_report['success'] = msg.episode_report['SUCCESS'] + a.episode_end(rpr, msg.state, em_report) - ################################################ - # Importantly, update the global episode count # - self.current_episode += 1 # track data data.last_rpr = rpr - # If we have an action send that, otherwise mark message as handled and request new one. - if data.current_action is None: - msg.request_new() - else: - msg.act(data.current_action) - - self._run_lock.release() + ################################################ + # Importantly, actually do shit # + msg.act(data.current_action)
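
Taken together, the final patch fixes an informal agent-hook protocol: EpisodicManager owns the MissionManager connection and the message loop, and each agent object supplies id(), obs_to_rpr(), rpr_to_act(), and episode_end(). The sketch below is a minimal, hypothetical agent wired up the way runner.py and trainer.py do it; the import path, constructor arguments, and hook signatures come from the patches above, while the HoldStationAgent name and its fixed action are illustrative only.

from mivp_agent.episodic_manager import EpisodicManager


class HoldStationAgent:
    '''Hypothetical agent showing the hooks EpisodicManager expects.'''

    def __init__(self, own_id):
        self.own_id = own_id

    def id(self):
        # Used by the manager to route incoming messages to this agent.
        return self.own_id

    def obs_to_rpr(self, observation):
        # Collapse the raw MOOS state into a coarse representation; the
        # manager marks a transition whenever this value changes.
        return (round(observation['NAV_X']), round(observation['NAV_Y']))

    def rpr_to_act(self, rpr, observation, em_report):
        # Return a BHV_Agent instruction; em_report carries manager-level
        # bookkeeping such as 'completed_episodes'.
        return {'speed': 0.0, 'course': 0.0}

    def episode_end(self, rpr, observation, em_report):
        # Called at episode boundaries; trainer.py uses this for terminal
        # rewards, epsilon decay, and model saving.
        pass


if __name__ == '__main__':
    agents = [HoldStationAgent(f'agent_1{i}') for i in (1, 2, 3)]
    mgr = EpisodicManager(agents, 13, wait_for=[])  # 13 for ~10 full episodes, as in runner.py
    mgr.start('example')

Note that the final episodic_manager.py passes an em_report argument to rpr_to_act() and calls episode_end() at every episode boundary, while runner.py's Agent defines rpr_to_act(self, rpr, observation) and no episode_end(); the runner agent would need both adjustments to run against this version of the manager.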
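
The global epsilon blocks in trainer.py are flagged as a problem by the patch's own comments. One way to drop them, sketched here under the assumption that every Agent already shares one config dict and one QLearn instance, is to hold epsilon on a small shared object that the agents read in rpr_to_act() and decay in episode_end(). The SharedEpsilon name and the wiring comments are illustrative, not part of the patch.

class SharedEpsilon:
    '''Hypothetical holder so all agents see a single decaying epsilon.'''

    def __init__(self, config):
        self.value = config['epsilon_start']
        self._decay_start = config['epsilon_decay_start']
        self._decay_end = config['epsilon_decay_end']
        self._decay_amt = config['epsilon_decay_amt']

    def decay(self, completed_episodes):
        # Same decay schedule as trainer.py, just without module-level state.
        if self._decay_end >= completed_episodes >= self._decay_start:
            self.value -= self._decay_amt


# Wiring (hypothetical):
#   in __main__:          shared_eps = SharedEpsilon(config), passed to every Agent
#   in Agent.rpr_to_act:  self.current_action = self.q.get_action(rpr, e=self.shared_eps.value)
#   in Agent.episode_end: self.shared_eps.decay(em_report['completed_episodes'])

Because all agents hold a reference to the same object, a decay applied by whichever agent finishes an episode is immediately visible to the others, which matches the behaviour the module-level global currently provides.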