-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
29 lines (22 loc) · 1.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import gymnasium as gym
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from epsim.core import SHARE
from epsim.envs.epsp import EPSP
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.ppo_mask import MaskablePPO
import hydra
def mask_fn(env: gym.Env) -> np.ndarray:
    """Return the action mask for the current state of ``env``.

    The mask is read from the environment's ``world`` object: it is the entry
    of ``world._masks`` stored under the ``SHARE.DISPATCH_CODE`` key.

    Args:
        env: Environment exposing a ``world`` attribute that has a ``_masks``
            mapping (presumably an ``EPSP`` instance -- confirm with caller).

    Returns:
        The mask stored for ``SHARE.DISPATCH_CODE``.
    """
    world = env.world
    dispatch_mask = world._masks[SHARE.DISPATCH_CODE]
    return dispatch_mask
@hydra.main(config_path="./config", config_name="args", version_base="1.3")
def main(cfg: "DictConfig"):  # noqa: F821
    """Train a MaskablePPO agent on the EPSP environment and save the model.

    Builds 8 EPSP environments, each wrapped in ``ActionMasker`` with
    ``mask_fn``, trains an ``MlpPolicy`` with three hidden layers of 256 units
    for 200000 timesteps, and saves the result to ``models/ppo_mask``.

    Args:
        cfg: Configuration loaded by Hydra from ``./config`` (config name
            ``args``); forwarded unchanged to ``EPSP`` as ``args``.
    """

    def make_env():
        # Each vectorized worker gets its own EPSP instance; ActionMasker
        # attaches mask_fn so the masks are available to MaskablePPO.
        return ActionMasker(EPSP(render_mode=None, args=cfg), mask_fn)

    vec_env = make_vec_env(make_env, n_envs=8)
    try:
        model = MaskablePPO(
            "MlpPolicy",
            vec_env,
            verbose=1,
            policy_kwargs=dict(net_arch=[256, 256, 256]),
        )
        model.learn(total_timesteps=200000)
        model.save("models/ppo_mask")
    finally:
        # Release the environments even when training or saving fails.
        vec_env.close()
if __name__ == "__main__":
    # No arguments: the hydra.main decorator supplies ``cfg`` when main is called.
    main()