From fe1fddc3a6c0fbbcec6807b07cf2e5034c4f3ebc Mon Sep 17 00:00:00 2001
From: "adrien.bolling"
Date: Tue, 2 Jul 2024 12:58:19 +0200
Subject: [PATCH] Added the missing n_sample_weights arg for the
 log_all_multi_policy_metrics tracking in the cooperative MOMAPPO.

---
 momaland/learning/cooperative_momappo/continuous_momappo.py | 3 +++
 momaland/learning/cooperative_momappo/discrete_momappo.py   | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/momaland/learning/cooperative_momappo/continuous_momappo.py b/momaland/learning/cooperative_momappo/continuous_momappo.py
index 1dbaa46c..56c05787 100644
--- a/momaland/learning/cooperative_momappo/continuous_momappo.py
+++ b/momaland/learning/cooperative_momappo/continuous_momappo.py
@@ -97,6 +97,8 @@ def parse_args():
         help="the activation function for the neural networks")
     parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
         help="whether to anneal the learning rate linearly")
+    parser.add_argument("--n-sample-weights", type=int, default=10,
+        help="number of weights to sample for EUM and MUL computation")
 
     args = parser.parse_args()
     # fmt: on
@@ -594,6 +596,7 @@ def _env_step(runner_state):
             hv_ref_point=np.array(args.ref_point),
             reward_dim=reward_dim,
             global_step=weight_number * args.timesteps_per_weight,
+            n_sample_weights=args.n_sample_weights,
         )
         if args.save_policies:
             save_actor(actor_state, w, args)
diff --git a/momaland/learning/cooperative_momappo/discrete_momappo.py b/momaland/learning/cooperative_momappo/discrete_momappo.py
index 3d7a969c..fa9738a4 100644
--- a/momaland/learning/cooperative_momappo/discrete_momappo.py
+++ b/momaland/learning/cooperative_momappo/discrete_momappo.py
@@ -98,6 +98,8 @@ def parse_args():
         help="the activation function for the neural networks")
     parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
         help="whether to anneal the learning rate linearly")
+    parser.add_argument("--n-sample-weights", type=int, default=10,
+        help="number of weights to sample for EUM and MUL computation")
 
     args = parser.parse_args()
     # fmt: on
@@ -590,6 +592,7 @@ def _env_step(runner_state):
             hv_ref_point=np.array(args.ref_point),
             reward_dim=reward_dim,
             global_step=weight_number * args.timesteps_per_weight,
+            n_sample_weights=args.n_sample_weights,
         )
         if args.save_policies:
             save_actor(actor_state, w, args)
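
For context, the new --n-sample-weights flag controls how many weight vectors are drawn when the tracking call (log_all_multi_policy_metrics) estimates utility-based metrics such as expected utility (EUM) over the current front. The standalone sketch below only illustrates what that parameter governs; it is not the actual MORL-Baselines implementation, and the helpers sample_simplex_weights and expected_utility are hypothetical names introduced for this example.

# Illustrative sketch only (not the MORL-Baselines implementation): shows what
# --n-sample-weights controls, i.e. how many weight vectors are sampled when
# estimating a utility-based metric such as EUM over a Pareto front.
import argparse

import numpy as np


def sample_simplex_weights(n_weights, reward_dim, rng):
    """Draw weight vectors uniformly on the probability simplex (hypothetical helper)."""
    return rng.dirichlet(np.ones(reward_dim), size=n_weights)


def expected_utility(front, weights):
    """Mean of the best linearly scalarized return on the front, over the sampled weights."""
    # front: (n_policies, reward_dim), weights: (n_weights, reward_dim)
    return float(np.mean(np.max(front @ weights.T, axis=0)))


if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-sample-weights", type=int, default=10,
        help="number of weights to sample for EUM and MUL computation")
    # fmt: on
    args = parser.parse_args()

    rng = np.random.default_rng(seed=0)
    # Toy Pareto front: one row of vectorial returns per policy.
    front = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])
    weights = sample_simplex_weights(args.n_sample_weights, front.shape[1], rng)
    print(f"EUM estimate over {args.n_sample_weights} sampled weights: {expected_utility(front, weights):.3f}")

Increasing --n-sample-weights gives a lower-variance estimate of such metrics at the cost of a slightly more expensive logging step; the patch simply forwards the value instead of relying on the callee's default.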