-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_perimeter_defense.py
93 lines (85 loc) · 3.2 KB
/
run_perimeter_defense.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import maPDenv
import numpy as np
import argparse
def _build_parser():
    """Create the command-line parser used by this script.

    Every option has a default, so the script can be launched without any
    arguments.  Options are registered in the same order in which they are
    listed by ``--help``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, help text, value type, default) for each documented option.
    documented_options = (
        ('--env', 'environment ID', str, 'maPDefense-v0'),
        ('--render', 'whether to render', int, 0),
        ('--record', 'whether to record', int, 0),
        ('--ros', 'whether to use ROS', int, 0),
        ('--nb_agents', 'the number of agents', int, 2),
        ('--nb_targets', 'the number of targets', int, 2),
        ('--log_dir', 'a path to a directory to log your data', str, '.'),
    )
    for flag, help_text, value_type, default in documented_options:
        cli.add_argument(flag, help=help_text, type=value_type, default=default)

    # --map is registered without a help string.
    cli.add_argument('--map', type=str, default="empty")
    return cli


# Parsed once at import time; main() reads these module-level values.
parser = _build_parser()
args = parser.parse_args()
def _episode_finished(done):
    """Return True once the value returned by ``env.step`` marks the episode over.

    ``done`` is normally a dict whose ``'__all__'`` entry flags the end of the
    episode.  The usage notes at the bottom of this file say a time-limit
    wrapper may hand back a plain bool instead; indexing that bool would raise
    ``TypeError``, so any non-dict value is judged by its truth value.
    """
    if isinstance(done, dict):
        return bool(done['__all__'])
    return bool(done)


# @profile
def main():
    """Run a single episode with randomly sampled actions and print a summary.

    The environment is created from the module-level ``args``.  At every step
    one action is sampled from ``env.action_space`` for each agent id present
    in the latest observation dict, the team reward (``rew['__all__']``) is
    printed, and ``info['mean_nlogdetcov']`` / ``info['num_intruders']`` are
    accumulated.  When the episode ends, the sums and the episode length are
    printed.
    """
    env = maPDenv.make(
        args.env,
        render=args.render,
        record=args.record,
        ros=args.ros,
        directory=args.log_dir,
        map_name=args.map,
        num_agents=args.nb_agents,
        num_targets=args.nb_targets,
        is_training=False,
    )

    rewards = []
    nlogdetcov = []
    intruders = []
    action_dict = {}
    done = {'__all__': False}
    step = 0

    obs = env.reset()
    # `done` may be a per-agent dict or a bare bool (see the notes at the end
    # of this file), so the check is delegated to _episode_finished.
    while not _episode_finished(done):
        step += 1
        if args.render:
            env.render()

        # Random policy: only the agent ids are needed, not the observations.
        for agent_id in obs:
            action_dict[agent_id] = env.action_space.sample()

        obs, rew, done, info = env.step(action_dict)
        print(rew['__all__'])
        rewards.append(rew['__all__'])
        nlogdetcov.append(info['mean_nlogdetcov'])
        intruders.append(info['num_intruders'])

    print("Total episode reward : %.2f, Sum of neg logdet of the target belief covs : %.2f, Total num of intruders : %d, Ep_len : %d"%(np.sum(rewards),np.sum(nlogdetcov),np.sum(intruders), step))


if __name__ == "__main__":
    main()
"""
To use line_profiler, add @profile before the function you want to profile, then run:
kernprof -l run_perimeter_defense.py --env maPDefense-v0 --nb_agents 4 --nb_targets 4 --render 0
python -m line_profiler run_perimeter_defense.py.lprof
Examples:
>>> env = MyMultiAgentEnv()
>>> obs = env.reset()
>>> print(obs)
{
"agent_0": [2.4, 1.6],
"agent_1": [3.4, -3.2],
}
>>> obs, rew, done, info = env.step(
action_dict={
"agent_0": 1, "agent_1": 0,
})
>>> print(rew)
{
"agent_0": 3,
"agent_1": -1,
"__all__": 2,
}
>>> print(done)
# Due to the gym wrapper, `done` is the bool True once the TimeLimit is reached.
# During an episode it is a dict like the one below,
# so keep running for as long as `done` is a dict.
{
"agent_0": False, # agent_0 is still running
"agent_1": True, # agent_1 is done
"__all__": False, # the env is not done
}
>>> print(info)
{
"agent_0": {}, # info for agent_0
"agent_1": {}, # info for agent_1
}
"""