-
Notifications
You must be signed in to change notification settings - Fork 381
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
env(rjy): add ising model env (#782)
* env(rjy): add ising model env * fix(rjy): modify ising for mean field RL * fix(rjy): try to fix reward problem * fix(rjy): fix the multi-agent reward * polish(rjy): fix dqn net init * polish(rjy): fix format * polish(rjy): norm eval_episode_return * polish(rjy): polish ising model env * fix(rjy): fix subprocess manager * fix(rjy): fixed reward compatibility * polish(rjy): polish according to comments * polish(rjy): add replay for ising * polish(rjy): and ising replay
- Loading branch information
Showing
17 changed files
with
707 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from easydict import EasyDict | ||
from ding.utils import set_pkg_seed | ||
|
||
obs_shape = 4 | ||
action_shape = 2 | ||
num_agents = 100 | ||
dim_spin = 2 | ||
agent_view_sight = 1 | ||
|
||
ising_mfq_config = dict( | ||
exp_name='ising_mfq_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=8, | ||
n_evaluator_episode=8, | ||
num_agents=num_agents, | ||
dim_spin=dim_spin, | ||
agent_view_sight=agent_view_sight, | ||
manager=dict(shared_memory=False, ), | ||
), | ||
policy=dict( | ||
cuda=True, | ||
priority=False, | ||
model=dict( | ||
obs_shape=obs_shape + action_shape, # for we will concat the pre_action_prob into obs | ||
action_shape=action_shape, | ||
encoder_hidden_size_list=[128, 128, 512], | ||
init_bias=0, | ||
), | ||
nstep=3, | ||
discount_factor=0.99, | ||
learn=dict( | ||
update_per_collect=10, | ||
batch_size=32, | ||
learning_rate=0.0001, | ||
target_update_freq=500, | ||
), | ||
collect=dict(n_sample=96, ), | ||
eval=dict(evaluator=dict(eval_freq=1000, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=1., | ||
end=0.05, | ||
decay=250000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=100000, ), | ||
), | ||
), | ||
) | ||
ising_mfq_config = EasyDict(ising_mfq_config) | ||
main_config = ising_mfq_config | ||
ising_mfq_create_config = dict( | ||
env=dict( | ||
type='ising_model', | ||
import_names=['dizoo.ising_env.envs.ising_model_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict(type='dqn'), | ||
) | ||
ising_mfq_create_config = EasyDict(ising_mfq_create_config) | ||
create_config = ising_mfq_create_config | ||
|
||
if __name__ == '__main__': | ||
# or you can enter `ding -m serial -c ising_mfq_config.py -s 0` | ||
from ding.entry import serial_pipeline | ||
seed = 1 | ||
serial_pipeline((main_config, create_config), seed=seed, max_env_step=5e4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from dizoo.ising_env.config.ising_mfq_config import main_config, create_config | ||
from ding.entry import eval | ||
|
||
|
||
def main(): | ||
main_config.env.collector_env_num = 1 | ||
main_config.env.evaluator_env_num = 1 | ||
main_config.env.n_evaluator_episode = 1 | ||
ckpt_path = './ckpt_best.pth.tar' | ||
replay_path = './replay_videos' | ||
eval((main_config, create_config), seed=1, load_path=ckpt_path, replay_path=replay_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .ising_model_env import IsingModelEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import numpy as np | ||
|
||
from dizoo.ising_env.envs.ising_model.multiagent.core import IsingWorld, IsingAgent | ||
|
||
|
||
class Scenario(): | ||
|
||
def _calc_mask(self, agent, shape_size): | ||
# compute the neighbour mask for each agent | ||
if agent.view_sight == -1: | ||
# fully observed | ||
agent.spin_mask += 1 | ||
elif agent.view_sight == 0: | ||
# observe itself | ||
agent.spin_mask[agent.state.id] = 1 | ||
elif agent.view_sight > 0: | ||
# observe neighbours | ||
delta = list(range(-int(agent.view_sight), int(agent.view_sight) + 1, 1)) | ||
delta.remove(0) # agent itself is not counted as neighbour of itself | ||
for dt in delta: | ||
row = agent.state.p_pos[0] | ||
col = agent.state.p_pos[1] | ||
row_dt = row + dt | ||
col_dt = col + dt | ||
if row_dt in range(0, shape_size): | ||
agent.spin_mask[agent.state.id + shape_size * dt] = 1 | ||
if col_dt in range(0, shape_size): | ||
agent.spin_mask[agent.state.id + dt] = 1 | ||
|
||
# the graph is cyclic, most left and most right are neighbours | ||
if agent.state.p_pos[0] < agent.view_sight: | ||
tar = shape_size - (np.array(range(0, int(agent.view_sight - agent.state.p_pos[0]), 1)) + 1) | ||
tar = tar * shape_size + agent.state.p_pos[1] | ||
agent.spin_mask[tar] = [1] * len(tar) | ||
|
||
if agent.state.p_pos[1] < agent.view_sight: | ||
tar = shape_size - (np.array(range(0, int(agent.view_sight - agent.state.p_pos[1]), 1)) + 1) | ||
tar = agent.state.p_pos[0] * shape_size + tar | ||
agent.spin_mask[tar] = [1] * len(tar) | ||
|
||
if agent.state.p_pos[0] >= shape_size - agent.view_sight: | ||
tar = np.array(range(0, int(agent.view_sight - (shape_size - 1 - agent.state.p_pos[0])), 1)) | ||
tar = tar * shape_size + agent.state.p_pos[1] | ||
agent.spin_mask[tar] = [1] * len(tar) | ||
|
||
if agent.state.p_pos[1] >= shape_size - agent.view_sight: | ||
tar = np.array(range(0, int(agent.view_sight - (shape_size - 1 - agent.state.p_pos[1])), 1)) | ||
tar = agent.state.p_pos[0] * shape_size + tar | ||
agent.spin_mask[tar] = [1] * len(tar) | ||
|
||
def make_world(self, num_agents=100, agent_view=1): | ||
world = IsingWorld() | ||
world.agent_view_sight = agent_view | ||
world.dim_spin = 2 | ||
world.dim_pos = 2 | ||
world.n_agents = num_agents | ||
world.shape_size = int(np.ceil(np.power(num_agents, 1.0 / world.dim_pos))) | ||
world.global_state = np.zeros((world.shape_size, ) * world.dim_pos) | ||
# assume 0 external magnetic field | ||
world.field = np.zeros((world.shape_size, ) * world.dim_pos) | ||
|
||
world.agents = [IsingAgent(view_sight=world.agent_view_sight) for i in range(num_agents)] | ||
|
||
# make initial conditions | ||
self.reset_world(world) | ||
|
||
return world | ||
|
||
def reset_world(self, world): | ||
|
||
world_mat = np.array( | ||
range(np.power(world.shape_size, world.dim_pos))). \ | ||
reshape((world.shape_size,) * world.dim_pos) | ||
# init agent state and global state | ||
for i, agent in enumerate(world.agents): | ||
agent.name = 'agent %d' % i | ||
agent.color = np.array([0.35, 0.35, 0.85]) | ||
agent.state.id = i | ||
agent.state.p_pos = np.where(world_mat == i) | ||
agent.state.spin = np.random.choice(world.dim_spin) | ||
agent.spin_mask = np.zeros(world.n_agents) | ||
|
||
assert world.dim_pos == 2, "cyclic neighbour only support 2D now" | ||
self._calc_mask(agent, world.shape_size) | ||
world.global_state[agent.state.p_pos] = agent.state.spin | ||
|
||
n_ups = np.count_nonzero(world.global_state.flatten()) | ||
n_downs = world.n_agents - n_ups | ||
world.order_param = abs(n_ups - n_downs) / (world.n_agents + 0.0) | ||
|
||
def reward(self, agent, world): | ||
# turn the state into -1/1 for easy computing | ||
world.global_state[np.where(world.global_state == 0)] = -1 | ||
|
||
mask_display = agent.spin_mask.reshape((int(np.sqrt(world.n_agents)), -1)) | ||
|
||
local_reward = - 0.5 * world.global_state[agent.state.p_pos] \ | ||
* np.sum(world.global_state.flatten() * agent.spin_mask) | ||
|
||
world.global_state[np.where(world.global_state == -1)] = 0 | ||
return -local_reward | ||
|
||
def observation(self, agent, world): | ||
# get positions of all entities in this agent's reference frame | ||
# agent state is updated in the world.step() function already | ||
# update the changes of the world | ||
|
||
# return the neighbour state | ||
return world.global_state.flatten()[np.where(agent.spin_mask == 1)] | ||
|
||
def done(self, agent, world): | ||
if world.order_param == 1.0: | ||
return True | ||
return False |
Oops, something went wrong.