# her.py
from stable_baselines3 import DDPG, HerReplayBuffer
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.common.callbacks import EvalCallback
import gym
import wandb
import her_wrappers
# wandb.config needs an active run, so initialize the project first
wandb.init(project="fetch-reach-her")
conf = wandb.config
conf.learn_timesteps = 50000
conf.n_sampled_goal = 4
conf.batch_size = 20
conf.max_episode_length = 50
conf.cnn = False
conf.reward_type = 'dense'
if conf.cnn:
    env = her_wrappers.Robotics('FetchReach-v1', reward_type=conf.reward_type)
else:
    env = gym.make('FetchReach-v1', reward_type=conf.reward_type)
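# FetchReach-v1 exposes a Dict observation space with 'observation',
# 'achieved_goal' and 'desired_goal' keys, which is why the non-CNN branch
# pairs it with "MultiInputPolicy" below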
# Available strategies (cf. the HER paper): future, final, episode
goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
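# The string is equivalent to the imported enum, which can be used instead:
# goal_selection_strategy = GoalSelectionStrategy.FUTURE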
# If True, HER transitions are sampled online during training;
# if False, they are generated and stored when each episode ends
online_sampling = False
# Time limit for the episodes (FetchReach-v1 truncates at 50 steps)
max_episode_length = conf.max_episode_length
# Initialize the model; pass the configured batch size so conf.batch_size
# is actually used
model = DDPG(
    "CnnPolicy" if conf.cnn else "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    # Parameters for HER: each real transition is relabelled with
    # n_sampled_goal additional virtual goals
    replay_buffer_kwargs=dict(
        n_sampled_goal=conf.n_sampled_goal,
        goal_selection_strategy=goal_selection_strategy,
        online_sampling=online_sampling,
        max_episode_length=max_episode_length,
    ),
    batch_size=conf.batch_size,
    verbose=1,
)
# Evaluate the model every 500 steps, reusing the training env for evaluation
eval_callback = EvalCallback(eval_env=env, eval_freq=500, n_eval_episodes=10, log_path='logdir-her/dense/', verbose=True)
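# EvalCallback writes its results to evaluations.npz under log_path; the file
# can be loaded later with numpy.load() to plot learning curves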
# Train the model
model.learn(conf.learn_timesteps, callback=eval_callback)
model.save("./her_fetchreach")
# HER must be loaded with the env because it needs access to
# `env.compute_reward()`
model = DDPG.load('./her_fetchreach', env=env)
obs = env.reset()
for _ in range(100):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, _ = env.step(action)
    print(reward)
    if done:
        obs = env.reset()
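# Release the environment's resources once rollouts are done
env.close()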