diff --git a/examples/QTable/model/new_runner.py b/examples/QTable/model/new_runner.py
new file mode 100755
index 0000000..521b34e
--- /dev/null
+++ b/examples/QTable/model/new_runner.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+import argparse
+import time
+
+from mivp_agent.episodic_manager import EpisodicManager
+from mivp_agent.util.math import dist
+#from mivp_agent.util.display import ModelConsole
+from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG
+from constants import DEFAULT_RUN_MODEL
+from model import load_model
+
+
+class Agent:
+    def __init__(self, own_id, opponent_id, model) -> None:
+        self.own_id = own_id
+        self.opponent_id = opponent_id
+        self.q, self.attack_actions, self.retreat_actions = load_model(model)
+        self.current_action = None
+
+    def id(self):
+        return self.own_id
+
+    def obs_to_rpr(self, observation):
+        model_representation = self.q.get_state(
+            observation['NAV_X'],
+            observation['NAV_Y'],
+            observation['NODE_REPORTS'][self.opponent_id]['NAV_X'],
+            observation['NODE_REPORTS'][self.opponent_id]['NAV_Y'],
+            observation['HAS_FLAG']
+        )
+        # console.tick(observation)  # Needs the full msg; observation is a msg.state
+
+        return model_representation
+
+    def rpr_to_act(self, rpr, observation, em_report):  # Why both rpr and observation? We discussed this but the reason is lost
+        self.current_action = self.q.get_action(rpr)
+
+        # Determine action set
+        if observation['HAS_FLAG']:
+            current_action_set = self.retreat_actions
+        else:
+            current_action_set = self.attack_actions
+
+        # Construct instruction for BHV_Agent
+        action = {
+            'speed': current_action_set[self.current_action]['speed'],
+            'course': current_action_set[self.current_action]['course']
+        }
+
+        flag_dist = abs(dist((observation['NAV_X'], observation['NAV_Y']), FIELD_BLUE_FLAG))
+
+        if flag_dist < 10:
+            action['posts'] = {
+                'FLAG_GRAB_REQUEST': f'vname={self.own_id}'
+            }
+
+        return action
+
+    def episode_end(self, rpr, observation, em_report):
+        # No learning at run time; nothing to do when an episode ends
+        pass
+
+
+if __name__ == '__main__':
+    # Create the required agents
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', default=DEFAULT_RUN_MODEL)
+    args = parser.parse_args()
+
+    agents = []
+    wait_for = []
+    for i in [1, 2, 3]:
+        agents.append(Agent(f'agent_1{i}', f'drone_2{i}', args.model))
+        wait_for.append(f'agent_1{i}')
+
+    #console = ModelConsole()  # Where will this live? Do we care? Needs a full msg instead of msg.state
+    mgr = EpisodicManager(agents, 13, wait_for=wait_for)  # 13 for ~10 full episodes... not ideal
+    mgr.start('runner')
+
diff --git a/examples/QTable/model/runner.py b/examples/QTable/model/runner.py
index d5fe1b6..152c089 100755
--- a/examples/QTable/model/runner.py
+++ b/examples/QTable/model/runner.py
@@ -2,79 +2,73 @@
 import argparse
 import time
 
-from mivp_agent.manager import MissionManager
+from mivp_agent.episodic_manager import EpisodicManager
 from mivp_agent.util.math import dist
-from mivp_agent.util.display import ModelConsole
+#from mivp_agent.util.display import ModelConsole
 from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG
-
 from constants import DEFAULT_RUN_MODEL
 from model import load_model
 
+
+class Agent:
+    def __init__(self, own_id, opponent_id, model) -> None:
+        self.own_id = own_id
+        self.opponent_id = opponent_id
+        self.q, self.attack_actions, self.retreat_actions = load_model(model)
+        self.current_action = None
+
+    def id(self):
+        return self.own_id
+
+    def obs_to_rpr(self, observation):
+        model_representation = self.q.get_state(
+            observation['NAV_X'],
+            observation['NAV_Y'],
+            observation['NODE_REPORTS'][self.opponent_id]['NAV_X'],
+            observation['NODE_REPORTS'][self.opponent_id]['NAV_Y'],
+            observation['HAS_FLAG']
+        )
+        # console.tick(observation)  # Needs the full msg; observation is a msg.state
+
+        return model_representation
+
+    def rpr_to_act(self, rpr, observation, em_report):
+        self.current_action = self.q.get_action(rpr)
+
+        # Determine action set
+        if observation['HAS_FLAG']:
+            current_action_set = self.retreat_actions
+        else:
+            current_action_set = self.attack_actions
-
-def run(args):
-    q, attack_actions, retreat_actions = load_model(args.model)
-
-    with MissionManager('runner', log=False) as mgr:
-        print('Waiting for sim vehicle connections...')
-        while mgr.get_vehicle_count() < 1:
-            time.sleep(0.1)
-        # ---------------------------------------
-        # Part 1: Asserting simulation state
-
-        last_state = None
-        current_action = None
-        current_action_set = None
-        console = ModelConsole()
-
-        while True:
-            # Listen for state
-            msg = mgr.get_message()
-            while False:
-                print('-------------------------------------------')
-                print(f"({msg.vname}) {msg.state['HAS_FLAG']}")
-                print('-------------------------------------------')
-                msg.request_new()
-                msg = mgr.get_message()
-
-            console.tick(msg)
-
-            # Detect state transitions
-            model_state = q.get_state(
-                msg.state['NAV_X'],
-                msg.state['NAV_Y'],
-                msg.state['NODE_REPORTS'][args.enemy]['NAV_X'],
-                msg.state['NODE_REPORTS'][args.enemy]['NAV_Y'],
-                msg.state['HAS_FLAG']
-            )
-
-            # Detect state transition
-            if model_state != last_state:
-                current_action = q.get_action(model_state)
-                last_state = model_state
-
-            # Determine action set
-            if msg.state['HAS_FLAG']:
-                current_action_set = retreat_actions
-            else:
-                current_action_set = attack_actions
-
-            # Construct instruction for BHV_Agent
-            action = {
-                'speed': current_action_set[current_action]['speed'],
-                'course': current_action_set[current_action]['course']
+        # Construct instruction for BHV_Agent
+        action = {
+            'speed': current_action_set[self.current_action]['speed'],
+            'course': current_action_set[self.current_action]['course']
+        }
+
+        flag_dist = abs(dist((observation['NAV_X'], observation['NAV_Y']), FIELD_BLUE_FLAG))
+
+        if flag_dist < 10:
+            action['posts'] = {
+                'FLAG_GRAB_REQUEST': f'vname={self.own_id}'
            }
 
-            flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG))
-            if flag_dist < 10:
-                action['posts'] = {
-                    'FLAG_GRAB_REQUEST': f'vname={msg.vname}'
-                }
-
-            msg.act(action)
+        return action
+
+    def episode_end(self, rpr, observation, em_report):
+        # No learning at run time; nothing to do when an episode ends
+        pass
+
 
 if __name__ == '__main__':
+    # Create the required agents
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', default=DEFAULT_RUN_MODEL)
-    parser.add_argument('--enemy', default='drone_21')
     args = parser.parse_args()
-    run(args)
\ No newline at end of file
+
+    agents = []
+    wait_for = []
+    for i in [1, 2, 3]:
+        agents.append(Agent(f'agent_1{i}', f'drone_2{i}', args.model))
+
+    #console = ModelConsole()  # Where will this live? Do we care? Needs a full msg instead of msg.state
+    mgr = EpisodicManager(agents, 13, wait_for=wait_for)  # 13 for ~10 full episodes... not ideal
+    mgr.start('runner')
+
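A note on the hard-coded `13` above: the manager appears to also count the partially-observed episode that each agent joins mid-way, which is presumably why 13 is requested to get roughly 10 full episodes. If that reading is correct (it is an inference from the "13 for ~10 full episodes" comment, not something this patch states), the magic number could be derived instead of hard-coded. A sketch, not part of the patch:

    # Hypothetical: pad the requested count by one partial episode per agent
    # so that `full_episodes` complete ones remain.
    full_episodes = 10
    episodes = full_episodes + len(agents)
    mgr = EpisodicManager(agents, episodes, wait_for=wait_for)
    mgr.start('runner')
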
diff --git a/examples/QTable/model/trainer.py b/examples/QTable/model/trainer.py
index 551aa81..ca398fc 100755
--- a/examples/QTable/model/trainer.py
+++ b/examples/QTable/model/trainer.py
@@ -1,16 +1,9 @@
 #!/usr/bin/env python3
-import os
-import time
-import wandb
 import argparse
-from tqdm import tqdm
-
-from mivp_agent.manager import MissionManager
-from mivp_agent.util.display import ModelConsole
+from mivp_agent.episodic_manager import EpisodicManager
 from mivp_agent.util.math import dist
-from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG
-
 from wandb_key import WANDB_KEY
+from mivp_agent.aquaticus.const import FIELD_BLUE_FLAG
 from constants import LEARNING_RATE, DISCOUNT, EPISODES
 from constants import FIELD_RESOLUTION
 from constants import EPSILON_START, EPSILON_DECAY_START, EPSILON_DECAY_AMT, EPSILON_DECAY_END
@@ -24,27 +17,22 @@
     'agent_11': 'drone_21',
     'agent_12': 'drone_22',
     'agent_13': 'drone_23',
-    'agent_14': 'drone_24',
-    'agent_15': 'drone_25'
+    #'agent_14': 'drone_24',
+    #'agent_15': 'drone_25'
 }
 
 EXPECTED_VEHICLES = [key for key in VEHICLE_PAIRING]
 
-class AgentData:
-    '''
-    Used to encapsulate the data needed to run each indivdual
-    agent, track state / episode transitions, and output useful
-    information and statistics.
-    '''
-    def __init__(self, vname, enemy):
-        self.vname = vname
-        self.enemy = enemy
-
-        # For running of simulation
-        self.agent_episode_count = 0
-        self.last_episode_num = None # Episode transitions
-        self.last_state = None # State transitions
-        self.had_flag = False # Capturing grab transitions for rewarding
-        self.current_action = None
+
+class Agent:
+    def __init__(self, own_id, opponent_id, q, config) -> None:
+        self.own_id = own_id
+        self.opponent_id = opponent_id
+        self.q = q
+        self.attack_actions = config['attack_actions']
+        self.retreat_actions = config['retreat_actions']
+        self.current_action = None
+        self.config = config
+        self.last_rpr = None
+        self.had_flag = False  # Needed before the first reset_tracking_vars() call
 
         # For debugging / output
         self.min_dist = None
@@ -52,225 +40,126 @@ def __init__(self, vname, enemy):
         self.last_MOOS_time = None
         self.MOOS_deltas = []
         self.grab_time = None
-
-    def new_episode(self, last_num):
-        self.last_episode_num = last_num
-        self.agent_episode_count += 1
-        self.last_state = None
+    def id(self):
+        '''
+        Required hook for EpisodicManager to access the agent's ID
+        '''
+        return self.own_id
+
+    def reset_tracking_vars(self):
+        '''
+        Reset per-episode tracking vars for this agent
+        '''
+        # Functional
+        self.last_rpr = None
         self.had_flag = False
         self.current_action = None
+        # Debugging and I/O
         self.min_dist = None
         self.episode_reward = 0
         self.last_MOOS_time = None
         self.MOOS_deltas.clear()
         self.grab_time = None
-
-def train(args, config, run_name):
-    agents = {}
-    with MissionManager('trainer', log=True, immediate_transition=False, id_suffix=run_name) as mgr:
-        # Create a directory for the model to save
-        model_save_dir = mgr.model_output_dir()
-
-        # Setup model
-        q = QLearn(
-            lr=config['lr'],
-            gamma=config['gamma'],
-            action_space_size=config['action_space_size'],
-            field_res=config['field_res'],
-            verbose=args.debug,
-            save_dir=model_save_dir
+    def obs_to_rpr(self, observation):
+        '''
+        Hook for EpisodicManager to convert an observation to a python representation.
+        Called on each new observation; used to detect state transitions.
+        '''
+        model_representation = self.q.get_state(
+            observation['NAV_X'],
+            observation['NAV_Y'],
+            observation['NODE_REPORTS'][self.opponent_id]['NAV_X'],
+            observation['NODE_REPORTS'][self.opponent_id]['NAV_Y'],
+            observation['HAS_FLAG']
        )
-        print('Waiting for sim vehicle connection...')
-        mgr.wait_for(EXPECTED_VEHICLES)
-
-        # Construct agent data object with enemy specified
-        agents = {}
-        for a in VEHICLE_PAIRING:
-            agents[a] = AgentData(a, VEHICLE_PAIRING[a])
-
-        # While all vehicle's pEpisodeManager are not PAUSED
-        print('Waiting for pEpisodeManager to enter PAUSED state...')
-        while not all(mgr.episode_state(vname) == 'PAUSED' for vname in agents):
-            msg = mgr.get_message()
-            if mgr.episode_state(msg.vname) != 'PAUSED':
-                # Instruct them to stop and pause
-                msg.stop()
-            else:
-                # O/w just keep up to date on their state
-                msg.request_new()
-
-        # ---------------------------------------
-        # Part 2: Global state initalization
-        episode_count = 0
-        epsilon = config['epsilon_start']
-        progress_bar = tqdm(total=config['episodes'], desc='Training')
-        # Debugging
-        total_sim_time = 0
-
-        # Initalize the epidode numbers from pEpisodeManager
-        e_nums = mgr.episode_nums()
-        for a in e_nums:
-            agents[a].last_episode_num = e_nums[a]
-
-        print('Running training...')
-        while episode_count < config['episodes']:
-            '''
-            This loop handles an indivual msg from an indivual agent
-            data from each agent is used during training
-            '''
-            # Listen for an agent's message & get it's data
-            msg = mgr.get_message()
-            # If pEpisodeManager is paused, start and continue to next agent
-            if msg.episode_state == 'PAUSED':
-                msg.mark_transition() # Initial state should be a transition
-                msg.start()
-                continue
-            agent_data = agents[msg.vname]
-            # Update debugging MOOS time
-            if agent_data.last_MOOS_time is not None:
-                agent_data.MOOS_deltas.append(msg.state['MOOS_TIME']-agent_data.last_MOOS_time)
-            agent_data.last_MOOS_time = msg.state['MOOS_TIME']
-
-            '''
-            Part 1: Translate MOOS state to model's state representation
-            '''
-            #print(msg.state['HAS_FLAG'])
-            model_state = q.get_state(
-                msg.state['NAV_X'],
-                msg.state['NAV_Y'],
-                msg.state['NODE_REPORTS'][agent_data.enemy]['NAV_X'],
-                msg.state['NODE_REPORTS'][agent_data.enemy]['NAV_Y'],
-                msg.state['HAS_FLAG']
+        return model_representation
+
+    def rpr_to_act(self, rpr, observation, em_report):
+        '''
+        Hook for EpisodicManager to get the next action from the current state representation.
+        Called on each new state representation; most of the per-step logic lives here.
+        '''
+        # Update previous state action pair
+        reward = self.config['reward_step']
+        if observation['HAS_FLAG'] and not self.had_flag:
+            reward = self.config['reward_grab']  # Should this be += to also account for the step?
+            self.had_flag = True
+            # Capture time for debugging (time after grab)
+            self.grab_time = observation['MOOS_TIME']
+
+        if self.last_rpr is not None:
+            self.q.update_table(
+                self.last_rpr,
+                self.current_action,
+                reward,
+                rpr
            )
+            self.episode_reward += self.config['reward_step']  # Shouldn't this just be += reward?
-
-            # Detect discrete state transitions
-            if model_state != agent_data.last_state:
-                '''
-                Part 2: Handle the ending of episodes
-                '''
-                # Mark this state as a transition to record it to logs
-                msg.mark_transition()
-
-                if msg.episode_report is None:
-                    assert agent_data.agent_episode_count == 0
-                elif msg.episode_report['DURATION'] < 2:
-                    # Bad episode, don't use data
-                    # Reset episode data and including 'last_episode_num'
-                    agent_data.new_episode(msg.episode_report['NUM'])
-                elif msg.episode_report['NUM'] != agent_data.last_episode_num:
-                    # Calculate reward based on pEpisodeManager's report
-                    reward = config['reward_failure']
-                    if msg.episode_report['SUCCESS']:
-                        reward = config['reward_capture']
-
-                    # Apply this reward to the QTable
-                    q.set_qvalue(
-                        agent_data.last_state,
-                        agent_data.current_action,
-                        reward
-                    )
-
-                    # Update the total sim time
-                    total_sim_time += msg.episode_report['DURATION']
-
-                    # Construct report
-                    report = {
-                        'episode_count': episode_count,
-                        'reward': agent_data.episode_reward+reward,
-                        'epsilon': round(epsilon, 3),
-                        'duration': round(msg.episode_report['DURATION'],2),
-                        'success': msg.episode_report['SUCCESS'],
-                        'min_dist': round(agent_data.min_dist, 2),
-                        'had_flag': agent_data.had_flag,
-                        'sim_time': round(total_sim_time, 2),
-                        'sim_days': round(total_sim_time / 86400, 2)
-                    }
-
-                    if len(agent_data.MOOS_deltas) != 0:
-                        report['avg_delta'] = round(sum(agent_data.MOOS_deltas)/len(agent_data.MOOS_deltas),2)
-                    else:
-                        report['avg_delta'] = 0.0
-
-                    if agent_data.grab_time is not None:
-                        report['post_grab_duration'] = round(msg.state['MOOS_TIME'] - agent_data.grab_time, 2)
-                    else:
-                        report['post_grab_duration'] = 0.0
-
-                    # Log the report
-                    if not args.no_wandb:
-                        wandb.log(report)
-                    tqdm.write(f'[{msg.vname}] ', end='')
-                    tqdm.write(', '.join([f'{k}: {report[k]}' for k in report]))
-
-                    # Decay epsilon
-                    if config['epsilon_decay_end'] >= episode_count >= config['epsilon_decay_start']:
-                        epsilon -= config['epsilon_decay_amt']
-
-                    # Reset episode data and including 'last_episode_num'
-                    agent_data.new_episode(msg.episode_report['NUM'])
-                    # Update global episode count
-                    episode_count += 1
-                    progress_bar.update(1)
-
-                    # Save model if applicable
-                    if episode_count % SAVE_EVERY == 0:
-                        q.save(
-                            config['attack_actions'],
-                            config['retreat_actions'],
-                            name=f'episode_{episode_count}'
-                        )
+        # TODO: epsilon is module-level state shared (and mutated) by every agent;
+        # it should live in config or be owned by the manager instead.
+        global epsilon
+
+        self.current_action = self.q.get_action(rpr, e=epsilon)
+        self.last_rpr = rpr
+
+        actions = self.attack_actions
+        if observation['HAS_FLAG']:  # Use retreat actions if already grabbed and... retreating
+            actions = self.retreat_actions
+        action = actions[self.current_action].copy()  # Copy out of reference paranoia
+
+        flag_dist = abs(dist((observation['NAV_X'], observation['NAV_Y']), FIELD_BLUE_FLAG))
+
+        if flag_dist < 10:
+            action['posts'] = {
+                'FLAG_GRAB_REQUEST': f'vname={self.own_id}'
+            }
+
+        return action
+
+    def episode_end(self, rpr, observation, em_report):
+        '''
+        Hook for EpisodicManager, called at the end of each episode.
+        End rewards, model saving, etc. go here.
+        '''
+        reward = self.config['reward_failure']
+        if em_report['success']:
+            reward = self.config['reward_capture']
+
+        # Apply this reward to the QTable
+        if self.last_rpr is not None:
+            self.q.set_qvalue(
+                self.last_rpr,
+                self.current_action,
+                reward
+            )
+        else:
+            print("last_rpr was None")
-
-            '''
-            Part 3: Handle updating actions / qtable in new states
-            '''
+        completed_episodes = em_report['completed_episodes']
-
-                # Update previous state action pair
-                reward = config['reward_step']
-                if msg.state['HAS_FLAG'] and not agent_data.had_flag:
-                    reward = config['reward_grab']
-                    agent_data.had_flag = True
-                    # Capture time for debugging (time after grab)
-                    agent_data.grab_time = msg.state['MOOS_TIME']
+        # TODO: same module-level epsilon hack as in rpr_to_act
+        global epsilon
-
-                if agent_data.last_state is not None:
-                    q.update_table(
-                        agent_data.last_state,
-                        agent_data.current_action,
-                        reward,
-                        model_state
-                    )
-                    agent_data.episode_reward += config['reward_step']
-
-                # Update tracking data
-                agent_data.current_action = q.get_action(model_state, e=epsilon)
-                agent_data.last_state = model_state
-
-            '''
-            Part 4: Even when agent is not in new state, keep preforming
-            the action that was calcualted on the when the state transitioned
-            '''
-            actions = config['attack_actions']
-            if msg.state['HAS_FLAG']: # Use retreat actions if already grabbed and... retreating
-                actions = config['retreat_actions']
-            action = actions[agent_data.current_action].copy() # Copy out of reference paranoia
-
-            flag_dist = abs(dist((msg.state['NAV_X'], msg.state['NAV_Y']), FIELD_BLUE_FLAG))
-            # If this agent can grab the flag, do so
-            if flag_dist < 10:
-                action['posts'] = {
-                    'FLAG_GRAB_REQUEST': f'vname={msg.vname}'
-                }
+        # Decay epsilon
+        if self.config['epsilon_decay_end'] >= completed_episodes >= self.config['epsilon_decay_start']:
+            epsilon -= self.config['epsilon_decay_amt']
-
-            # Send action
-            msg.act(action)
+        self.reset_tracking_vars()
-
-            # Debugging stuff
-            if agent_data.min_dist is None or agent_data.min_dist > flag_dist:
-                agent_data.min_dist = flag_dist
+        # Save model if applicable
+        if completed_episodes % self.config['save_every'] == 0:
+            self.q.save(
+                self.config['attack_actions'],
+                self.config['retreat_actions'],
+                name=f'episode_{completed_episodes}'
+            )
 
 
 if __name__ == '__main__':
@@ -301,12 +190,24 @@ def train(args, config, run_name):
         'reward_capture': REWARD_CAPTURE,
         'reward_failure': REWARD_FAILURE,
         'reward_step': REWARD_STEP,
+        'save_every': SAVE_EVERY,
     }
 
-    if args.no_wandb:
-        train(args, config, None)
-    else:
-        wandb.login(key=WANDB_KEY)
-        with wandb.init(project='mivp_agent_qtable', config=config):
-            config = wandb.config
-            train(args, config, f'{wandb.run.name}')
\ No newline at end of file
+    # Setup model
+    q = QLearn(
+        lr=config['lr'],
+        gamma=config['gamma'],
+        action_space_size=config['action_space_size'],
+        field_res=config['field_res'],
+        verbose=args.debug,
+        save_dir=SAVE_DIR
+    )
+    epsilon = config['epsilon_start']
+
+    agents = []
+    wait_for = []
+    for key in VEHICLE_PAIRING:
+        agents.append(Agent(key, VEHICLE_PAIRING[key], q, config))
+
+    mgr = EpisodicManager(agents, EPISODES, wait_for=wait_for)
+    mgr.start('trainer')
\ No newline at end of file
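The two `global epsilon` blocks flagged with TODOs above could be removed by handing every Agent a shared schedule object through `config`. A minimal sketch of that idea — the class name and config key are hypothetical, not part of this patch:

    class EpsilonSchedule:
        '''Shared exploration schedule; one instance is passed to every Agent via config.'''
        def __init__(self, start, decay_start, decay_end, decay_amt):
            self.value = start
            self.decay_start = decay_start
            self.decay_end = decay_end
            self.decay_amt = decay_amt

        def decay(self, completed_episodes):
            # Same window check the trainer uses, just owned by one object
            if self.decay_end >= completed_episodes >= self.decay_start:
                self.value -= self.decay_amt

    # In __main__:        config['epsilon'] = EpsilonSchedule(EPSILON_START, EPSILON_DECAY_START,
    #                                                          EPSILON_DECAY_END, EPSILON_DECAY_AMT)
    # In rpr_to_act:      self.current_action = self.q.get_action(rpr, e=self.config['epsilon'].value)
    # In episode_end:     self.config['epsilon'].decay(completed_episodes)
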
diff --git a/examples/QTable/run.sh b/examples/QTable/run.sh
index e78a8a4..e791702 100755
--- a/examples/QTable/run.sh
+++ b/examples/QTable/run.sh
@@ -30,8 +30,8 @@ echo "Launching simulation..."
 # Give time for simulation to startup
 sleep 5
 
-# Start trainer with any arguments passed to this script
+# Start runner with any arguments passed to this script
 echo "Launching runner..."
-./model/runner.py "$@"
+./model/new_runner.py "$@"
 
 cd $PREVIOUS_WD
\ No newline at end of file
diff --git a/examples/QTable/scripts/gui_launch.sh b/examples/QTable/scripts/gui_launch.sh
index 52987bf..8179d80 100755
--- a/examples/QTable/scripts/gui_launch.sh
+++ b/examples/QTable/scripts/gui_launch.sh
@@ -8,9 +8,17 @@ TIME_WARP="4"
 cd ../mission/heron
 # Launch a agents
 ./launch_heron.sh red agent 11 --color=orange $TIME_WARP > /dev/null &
+./launch_heron.sh red agent 12 --color=green $TIME_WARP > /dev/null &
+./launch_heron.sh red agent 13 --color=purple $TIME_WARP > /dev/null &
+#./launch_heron.sh red agent 14 --color=gray $TIME_WARP > /dev/null &
+#./launch_heron.sh red agent 15 --color=yellow $TIME_WARP > /dev/null &
 
 # Launch a blue drone
 ./launch_heron.sh blue drone 21 --behavior=DEFEND --color=orange $TIME_WARP > /dev/null &
+./launch_heron.sh blue drone 22 --behavior=DEFEND --color=green $TIME_WARP > /dev/null &
+./launch_heron.sh blue drone 23 --behavior=DEFEND --color=purple $TIME_WARP > /dev/null &
+#./launch_heron.sh blue drone 24 --behavior=DEFEND --color=gray $LOGGING $TIME_WARP > /dev/null &
+#./launch_heron.sh blue drone 25 --behavior=DEFEND --color=yellow $LOGGING $TIME_WARP > /dev/null &
 
 cd ..
 cd shoreside
diff --git a/src/python_module/src/mivp_agent/episodic_manager.py b/src/python_module/src/mivp_agent/episodic_manager.py
new file mode 100644
index 0000000..19dcd47
--- /dev/null
+++ b/src/python_module/src/mivp_agent/episodic_manager.py
@@ -0,0 +1,97 @@
+
+from mivp_agent.manager import MissionManager
+
+
+class AgentData:
+    def __init__(self, vname) -> None:
+        self.vname = vname
+
+        '''
+        Used to recognize whether a reported episode 'NUM' is a new one. When a new
+        episode appears, the EpisodicManager's episode count must be incremented.
+        '''
+        self.last_episode = None
+
+        '''
+        General idea: only call `rpr_to_act` when we reach a new state, so we need to
+        store the first action calculated for that state.
+        '''
+        self.current_action = None
+
+        '''
+        Used to identify new states and perform the action lookup / calculation from the model.
+        '''
+        self.last_rpr = None
+
+
+class EpisodicManager:
+    def __init__(self, agents, episodes, wait_for=None) -> None:
+        '''
+        SETUP AGENTS
+        '''
+        self.agents = agents
+        self.agent_data = {}
+
+        # Combine the agents' ids and any additional wait_for entries
+        self.wait_for = wait_for
+        if self.wait_for is None:
+            self.wait_for = []
+
+        # Setup things specific to each agent
+        for a in self.agents:
+            # Make sure they are waited for
+            self.wait_for.append(a.id())
+            # Make an agent state for them
+            self.agent_data[a.id()] = AgentData(a.id())
+
+        '''
+        SETUP EPISODE TRACKING
+        '''
+        self.episodes = episodes
+        self.completed_episode = 0
+
+    def _build_report(self):
+        report = {
+            'completed_episodes': self.completed_episode,
+        }
+
+        return report
+
+    def start(self, task, log=True):
+        with MissionManager(task, log=log) as mgr:
+            mgr.wait_for(self.wait_for)
+
+            while self.completed_episode < self.episodes:
+                msg = mgr.get_message()
+
+                # Find the agent this message belongs to...
+                for a in self.agents:
+                    if msg.vname == a.id():
+                        data = self.agent_data[a.id()]
+
+                        # Probably always needed
+                        rpr = a.obs_to_rpr(msg.state)  # 'state' nomenclature still lurking
+                        #rpr = a.obs_to_rpr(msg)  # full msg for console.tick
+
+                        if data.last_rpr != rpr:
+                            msg.mark_transition()
+                            em_report = self._build_report()
+                            # Still need msg.state here because rpr_to_act expects an observation
+                            data.current_action = a.rpr_to_act(rpr, msg.state, em_report)
+
+                        # Update episode count if applicable
+                        if data.last_episode != msg.episode_report['NUM']:
+                            if data.last_episode is not None:
+                                # Update the global episode count (skip the first, partial one)
+                                self.completed_episode += 1
+                                em_report = self._build_report()
+                                em_report['success'] = msg.episode_report['SUCCESS']
+                                a.episode_end(rpr, msg.state, em_report)
+                            data.last_episode = msg.episode_report['NUM']
+
+                        # Track data
+                        data.last_rpr = rpr
+
+                        # Importantly, actually act on this message
+                        msg.act(data.current_action)
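Taken together, the call sites in `EpisodicManager.start()` define the hook surface an agent must expose. The following is a documentation sketch inferred from those call sites (not code added by this patch; the class name is illustrative):

    class MyAgent:
        def id(self):
            '''Vehicle name the manager waits for and routes messages by.'''

        def obs_to_rpr(self, observation):
            '''Convert a msg.state observation into the model's state representation.'''

        def rpr_to_act(self, rpr, observation, em_report):
            '''Return the BHV_Agent action dict when a new representation is observed.'''

        def episode_end(self, rpr, observation, em_report):
            '''Called when pEpisodeManager reports a new episode NUM; em_report carries
            'completed_episodes' and 'success'.'''
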
diff --git a/src/python_module/test/test_all.py b/src/python_module/test/test_all.py
index 00c087a..90f3313 100755
--- a/src/python_module/test/test_all.py
+++ b/src/python_module/test/test_all.py
@@ -15,6 +15,7 @@
 import test_bridge
 import test_log
 import test_manager
+import test_episodic_manager
 import test_data_structures
 import test_proto
 import test_consumer
@@ -36,6 +37,7 @@
     suite.addTest(unittest.makeSuite(test_consumer.TestConsumer))
     suite.addTest(unittest.makeSuite(test_manager.TestManagerCore))
     suite.addTest(unittest.makeSuite(test_manager.TestManagerLogger))
+    suite.addTest(unittest.makeSuite(test_episodic_manager.TestEpisodicManager))
     suite.addTest(unittest.makeSuite(test_data_structures.TestLimitedHistory))
     suite.addTest(unittest.makeSuite(test_proto.TestLogger))
diff --git a/src/python_module/test/test_episodic_manager.py b/src/python_module/test/test_episodic_manager.py
new file mode 100644
index 0000000..03f0864
--- /dev/null
+++ b/src/python_module/test/test_episodic_manager.py
@@ -0,0 +1,126 @@
+import unittest
+import time
+from threading import Thread
+
+from mivp_agent.bridge import ModelBridgeClient
+from mivp_agent.episodic_manager import EpisodicManager
+from mivp_agent.const import KEY_ID, KEY_EPISODE_MGR_REPORT, KEY_EPISODE_MGR_STATE
+
+DUMMY_STATE = {
+    KEY_ID: 'felix',
+    'MOOS_TIME': 16923.012,
+    'NAV_X': 98.0,
+    'NAV_Y': 40.0,
+    'NAV_HEADING': 180,
+    KEY_EPISODE_MGR_REPORT: 'NUM=0,DURATION=60.57,SUCCESS=false,WILL_PAUSE=false',
+    KEY_EPISODE_MGR_STATE: 'PAUSED'
+}
+
+
+class FakeAgent:
+    def __init__(self, id):
+        self._id = id
+        self.states = []
+
+    def id(self):
+        return self._id
+
+    def obs_to_rpr(self, state):
+        self.states.append(state)
+
+    def rpr_to_act(self, rpr, state):
+        pass
+
+
+class TestEpisodicManager(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.agents = {
+            'misha': FakeAgent('misha'),
+            'mike': FakeAgent('mike')
+        }
+
+        # For each agent create a client
+        self.clients = {}
+        for name in self.agents:
+            self.clients[name] = ModelBridgeClient()
+
+        # Get a list of agent objects for the episodic manager
+        agnt_lst = [agent for _, agent in self.agents.items()]
+        self.mgr = EpisodicManager(agnt_lst, 100)
+
+        # Start manager in another thread
+        self.mgr_thread = Thread(target=self.mgr.start, args=('tester',))
+        self.mgr_thread.start()
+
+        # Connect clients to the server started by the manager
+        all_connected = False
+        while not all_connected:
+            all_connected = True
+
+            for c in self.clients:
+                if not self.clients[c].is_connected():
+                    self.clients[c].connect()
+                    all_connected = False
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        if self.mgr.is_running():
+            self.mgr.stop()
+
+        self.mgr_thread.join()
+
+        time.sleep(0.1)
+
+        for c in self.clients:
+            self.clients[c].close()
+
+    def test_constructor(self):
+        # Test failure with no args
+        self.assertRaises(TypeError, EpisodicManager)
+
+        agents = []
+        agents.append(FakeAgent('joe'))
+        agents.append(FakeAgent('carter'))
+
+        # Test failure with no episodes
+        self.assertRaises(TypeError, EpisodicManager, agents)
+
+        # Test success
+        EpisodicManager(agents, 100)
+
+    def test_stop(self):
+        # Sanity check
+        self.assertTrue(self.mgr.is_running())
+
+        # Give the stop signal
+        self.mgr.stop()
+
+        # Allow time for release then assert
+        time.sleep(0.2)
+        self.assertFalse(self.mgr.is_running())
+
+    def test_routing(self):
+        # Setup two states to be sent
+        misha_state = DUMMY_STATE.copy()
+        misha_state[KEY_ID] = 'misha'
+        misha_state['NAV_X'] = 1.0
+        mike_state = DUMMY_STATE.copy()
+        mike_state[KEY_ID] = 'mike'
+        mike_state['NAV_X'] = 0.0
+
+        # Send states
+        self.clients['misha'].send_state(misha_state)
+        self.clients['mike'].send_state(mike_state)
+
+        # Allow time for propagation
+        time.sleep(0.2)
+
+        # Check that both received one state
+        self.assertEqual(len(self.agents['misha'].states), 1)
+        self.assertEqual(len(self.agents['mike'].states), 1)
+
+        # Check that they received the proper state
+        self.assertEqual(self.agents['misha'].states[0]['NAV_X'], 1.0)
+        self.assertEqual(self.agents['mike'].states[0]['NAV_X'], 0.0)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
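The routing test only exercises `obs_to_rpr`: FakeAgent never returns a representation and the dummy report never changes its episode number, so `rpr_to_act` and `episode_end` are never reached. A richer test double along the following lines (a sketch only, not included in this patch) would let future tests drive those hooks as well, by returning a changing representation and recording the reports it receives; triggering `episode_end` would additionally require sending states whose `NUM` field changes between messages:

    class TransitionFakeAgent(FakeAgent):
        '''Sketch of a fuller test double for the EpisodicManager hook surface.'''
        def __init__(self, id):
            super().__init__(id)
            self.actions_requested = 0
            self.episode_reports = []

        def obs_to_rpr(self, state):
            super().obs_to_rpr(state)
            return state['NAV_X']  # any value that differs between messages forces a transition

        def rpr_to_act(self, rpr, state, em_report=None):
            self.actions_requested += 1
            return {'speed': 0.0, 'course': 0.0}

        def episode_end(self, rpr, state, em_report):
            self.episode_reports.append(em_report)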