Skip to content
This repository has been archived by the owner on Feb 6, 2023. It is now read-only.

Exceptional agents #11

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 59 additions & 13 deletions chainerrl_visualizer/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def launch_visualizer(agent, gymlike_env, action_meanings, log_dir='log_space',
if isinstance(gymlike_env, gym.Env):
modify_gym_env_render(gymlike_env)

compensate_agent_lacked_method(agent)
profile = inspect_agent(agent, gymlike_env, contains_rnn)

job_queue = Queue()
Expand Down Expand Up @@ -99,8 +100,65 @@ def prepare_log_directory(log_dir): # log_dir is assumed to be full path
return True


def compensate_agent_lacked_method(agent):
    """Patch a default ``batch_states`` onto agents that do not define one.

    Some agents lack the ``batch_states`` attribute used elsewhere in the
    visualizer; fall back to chainerrl's stock implementation for those.
    """
    if hasattr(agent, 'batch_states'):
        return
    agent.batch_states = chainerrl.misc.batch_states


def validate_agent_profile(profile):
    """Raise if *profile* describes a model output the visualizer cannot handle.

    The profile must name at least one of a distribution type or an
    action-value type, and any named type must be in the supported sets.
    """
    dist_type = profile['distribution_type']
    av_type = profile['action_value_type']

    if dist_type is None and av_type is None:
        raise Exception('Outputs of model do not contain ActionValue nor DistributionType')

    if av_type is not None and av_type not in SUPPORTED_ACTION_VALUES:
        raise Exception('ActionValue type {} is not supported for now'.format(av_type))

    if dist_type is not None and dist_type not in SUPPORTED_DISTRIBUTIONS:
        raise Exception('Distribution type {} is not supported for now'.format(dist_type))


# workaround
def inspect_exceptional_agent(agent, gymlike_env, contains_rnn):
    """Build and validate a profile dict for the exceptional agents.

    These agents expose ``policy`` instead of a ``model`` attribute, so the
    probe forward pass goes through the policy directly.  Only the
    distribution type is inspected; no action-value type is recorded.
    """
    obs = gymlike_env.reset()
    policy = agent.policy

    # workaround: some agents have no `xp`; default to plain numpy
    xp = getattr(agent, 'xp', np)

    batched_obs = agent.batch_states([obs], xp, agent.phi)

    # state_kept() restores the recurrent state afterwards, so this probe
    # pass does not perturb the agent.
    if isinstance(policy, chainerrl.recurrent.RecurrentChainMixin):
        with policy.state_kept():
            dist = policy(batched_obs)
    else:
        dist = policy(batched_obs)

    profile = {
        'contains_recurrent_model': contains_rnn,
        'state_value_returned': True,
        'distribution_type': type(dist).__name__,
        'action_value_type': None,
    }

    validate_agent_profile(profile)

    return profile


# Create and return dict contains agent profile
def inspect_agent(agent, gymlike_env, contains_rnn):
# workaround
# These three agents are exceptional in that the other agents have `model` attribute
# and `model.__call__()` returns outputs of the model.
if type(agent).__name__ in ['TRPO', 'DDPG', 'PGT']:
return inspect_exceptional_agent(agent, gymlike_env, contains_rnn)

profile = {
'contains_recurrent_model': contains_rnn,
'state_value_returned': False,
Expand Down Expand Up @@ -142,18 +200,6 @@ def inspect_agent(agent, gymlike_env, contains_rnn):
raise Exception(
'Model output type of {} is not supported for now'.format(type(output).__name__))

# Validations
if profile['distribution_type'] is None and profile['action_value_type'] is None:
raise Exception('Outputs of model do not contain ActionValue nor DistributionType')

if profile['action_value_type'] is not None \
and profile['action_value_type'] not in SUPPORTED_ACTION_VALUES:
raise Exception('ActionValue type {} is not supported for now'.format(
profile['action_value_type']))

if profile['distribution_type'] is not None \
and profile['distribution_type'] not in SUPPORTED_DISTRIBUTIONS:
raise Exception('Distribution type {} is not supported for now'.format(
profile['distribution_type']))
validate_agent_profile(profile)

return profile
95 changes: 81 additions & 14 deletions chainerrl_visualizer/worker_jobs/rollout_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ def rollout(agent, gymlike_env, rollout_dir, step_count, obs_list, render_img_li
render_img_list[:] = [] # Clear the shared render images list

# workaround
if hasattr(agent, 'xp'):
xp = agent.xp
else:
xp = np
if not hasattr(agent, 'xp'):
agent.xp = np

log_fp = open(os.path.join(rollout_dir, ROLLOUT_LOG_FILE_NAME), 'a')
writer = jsonlines.Writer(log_fp)
Expand All @@ -50,17 +48,13 @@ def rollout(agent, gymlike_env, rollout_dir, step_count, obs_list, render_img_li
obs_list.append(obs)
render_img_list.append(rendered)

if isinstance(agent, chainerrl.recurrent.RecurrentChainMixin):
with agent.model.state_kept():
outputs = agent.model(agent.batch_states([obs], xp, agent.phi))
# workaround
# These three agents are exceptional in that the other agents have `model` attribute
# and `model.__call__()` returns outputs of the model.
if type(agent).__name__ in ['TRPO', 'DDPG', 'PGT']:
obs, r, done, action, outputs = _step_exceptional_agent(agent, gymlike_env, obs)
else:
outputs = agent.model(agent.batch_states([obs], xp, agent.phi))

if not isinstance(outputs, tuple):
outputs = tuple((outputs,))

action = agent.act(obs)
obs, r, done, info = gymlike_env.step(action)
obs, r, done, action, outputs = _step_agent(agent, gymlike_env, obs)

log_entries = dict()
log_entries['step'] = t
Expand Down Expand Up @@ -133,6 +127,79 @@ def rollout(agent, gymlike_env, rollout_dir, step_count, obs_list, render_img_li
raise Exception(error_msg)


def _step_agent(agent, gymlike_env, obs):
    """Advance the environment one step with a standard (``model``-bearing) agent.

    Returns ``(obs, reward, done, action, outputs)`` where ``outputs`` is the
    model's raw output for the current observation, always wrapped in a tuple.
    """
    batched_obs = agent.batch_states([obs], agent.xp, agent.phi)

    # state_kept() restores the recurrent state afterwards, so this extra
    # forward pass does not perturb the agent.
    if isinstance(agent, chainerrl.recurrent.RecurrentChainMixin):
        with agent.model.state_kept():
            outputs = agent.model(batched_obs)
    else:
        outputs = agent.model(batched_obs)

    if not isinstance(outputs, tuple):
        outputs = (outputs,)

    action = agent.act(obs)
    next_obs, reward, done, _ = gymlike_env.step(action)

    return next_obs, reward, done, action, outputs


def _step_exceptional_agent(agent, gymlike_env, obs):
policy = agent.policy
agent_type = type(agent).__name__
b_state = agent.batch_states([obs], agent.xp, agent.phi)

if agent_type in ['DDPG', 'PGT']:
if isinstance(policy, chainerrl.recurrent.RecurrentChainMixin):
with policy.state_kept():
action_dist = policy(b_state)
else:
action_dist = policy(b_state)

# workaround
# If `agent.act()` called when `agent.q_function` has LSTM,
# the params of the model will change. So, we have to directly get `action`
# from `action_dist`. `action` is needed for parameter of `q_function()`.
if agent_type == 'DDPG':
action = action_dist.sample()
else: # PGT
if agent.act_deterministically:
action = action_dist.most_probable
else:
action = action_dist.sample()

q_function = agent.q_function
if isinstance(q_function, chainerrl.recurrent.RecurrentChainMixin):
with q_function.state_kept():
q_value = q_function(b_state, action)
else:
q_value = q_function(b_state, action)

outputs = (action_dist, q_value)

elif agent_type == 'TRPO':
if isinstance(policy, chainerrl.recurrent.RecurrentChainMixin):
with policy.state_kept():
action_dist = policy(b_state)
else:
action_dist = policy(b_state)

value_function = agent.vf
if isinstance(value_function, chainerrl.recurrent.RecurrentChainMixin):
with value_function.state_kept():
state_value = value_function(b_state)
else:
state_value = value_function(b_state)

outputs = (action_dist, state_value)
else:
raise Exception('{} is not one of the exceptional agent types'.format(agent_type))

action = agent.act(obs)
obs, r, done, _ = gymlike_env.step(action)

return obs, r, done, action, outputs


def _save_env_render(rendered, rollout_dir):
image = Image.fromarray(rendered)
image_path = os.path.join(rollout_dir, 'images', generate_random_string(11) + '.png')
Expand Down