refactor(gry): refactor reward model #636

Open: wants to merge 63 commits into base: main
Changes from 25 commits

Commits (63)
c372c07
refactor network and red reward model
ruoyuGao Apr 5, 2023
6718e4a
create reward model utils
ruoyuGao Apr 5, 2023
be7039a
polish network and reward model utils, provide test for them
ruoyuGao Apr 6, 2023
a4de466
refactor network for two method: learn and forward
ruoyuGao Apr 10, 2023
d615c14
Merge branch 'main' into ruoyugao
ruoyuGao Apr 10, 2023
7a8ec6e
refactor rnd
ruoyuGao Apr 11, 2023
55c7be8
refactor gail
ruoyuGao Apr 11, 2023
ff60716
fix gail for unit test
ruoyuGao Apr 11, 2023
6b80392
refactor icm
ruoyuGao Apr 12, 2023
25d49b5
fix wrong unit test in test_reward_model_utils
ruoyuGao Apr 12, 2023
c081ff0
refactor gcl and pwil
ruoyuGao Apr 13, 2023
f1218cd
refactor pdeil
ruoyuGao Apr 13, 2023
d9060c2
add hidden_size_list to gail
ruoyuGao Apr 13, 2023
179182a
change gail test for new config
ruoyuGao Apr 13, 2023
d067731
refactor trex network
ruoyuGao Apr 14, 2023
29f0d55
fix style and wrong import
ruoyuGao Apr 14, 2023
4ec0bd3
fix style for trex
ruoyuGao Apr 14, 2023
800f090
Merge branch 'main' into ruoyugao
ruoyuGao Apr 14, 2023
c64b5c7
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Apr 14, 2023
660af32
fix unit test for trex onppo
ruoyuGao Apr 17, 2023
1b0d579
Merge branch 'main' into ruoyugao
ruoyuGao Apr 21, 2023
b4e81dd
refactor ngu and provide cartpole config file
ruoyuGao Apr 21, 2023
eddc80d
change reward entry
ruoyuGao Apr 26, 2023
6e2b867
change trex entry to new entry, combine old trex test to new test
ruoyuGao Apr 26, 2023
e25d265
Merge branch 'main' into ruoyugao
ruoyuGao Apr 28, 2023
97634dc
refactor trex config file
ruoyuGao Apr 28, 2023
f099cac
refactor trex config file
ruoyuGao Apr 28, 2023
0c48c08
refactor trex config file
ruoyuGao Apr 28, 2023
594d619
add gail to new reward entry
ruoyuGao May 3, 2023
58a2bff
remove preferenced based irl entry(used for trex, drex before)
ruoyuGao May 3, 2023
e9db652
Merge branch 'main' into ruoyugao
ruoyuGao May 3, 2023
822d7a4
remove unuse code in gcl
ruoyuGao May 3, 2023
be03aa9
change clear data from pipeline to RM && add ngu to new entry
ruoyuGao May 4, 2023
d3ce3e2
remove ngu old entry
ruoyuGao May 4, 2023
4c19aa3
fix env pool test bug
ruoyuGao May 4, 2023
0cc2149
add drex to new entry
ruoyuGao May 4, 2023
5b4e4cc
fix unit test for trex and gail
ruoyuGao May 4, 2023
ff4de47
fix style
ruoyuGao May 4, 2023
9e63ef1
fix style for drex unittest
ruoyuGao May 5, 2023
ca2e2db
fix drex unittest
ruoyuGao May 5, 2023
8716afe
fix bug in minigrid env
ruoyuGao May 5, 2023
9036141
add explain for rm utils
ruoyuGao May 6, 2023
6b9754a
move RM unittest into one file
ruoyuGao May 6, 2023
a52a1c0
Merge branch 'main' into ruoyugao
ruoyuGao May 6, 2023
a5c7989
add drex config
ruoyuGao May 8, 2023
d631237
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao May 8, 2023
f42d131
fix ngu wrapper bug in minigrid
ruoyuGao May 9, 2023
edff260
fix ngu wrapper bug in minigrid
ruoyuGao May 9, 2023
6ab66e1
Merge branch 'main' into ruoyugao
ruoyuGao May 10, 2023
cb0c627
refactor gcl, add it to reward entry
ruoyuGao May 22, 2023
016fbb3
refactor gcl config and bash format other config
ruoyuGao May 22, 2023
cf50148
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao May 22, 2023
919c01b
fix bug for test, remove wrong comment
ruoyuGao May 22, 2023
a1d0b3a
polish code for ngu, drex, base rm and entry
ruoyuGao Jun 6, 2023
0a0af3c
Merge branch 'main' into ruoyugao
ruoyuGao Jun 6, 2023
e310b4c
polish code for all rm
ruoyuGao Jun 6, 2023
92dc227
fix style for ngu
ruoyuGao Jun 6, 2023
a4f364d
polish comment for config files
ruoyuGao Jun 6, 2023
1f06dec
add gcl unit test
ruoyuGao Jun 9, 2023
a547b3b
polish RM
ruoyuGao Jun 19, 2023
97da5c6
fix style for rnd and icm
ruoyuGao Jun 20, 2023
774b2a4
fix style for rnd and icm
ruoyuGao Jun 20, 2023
b78e36c
fix style for icm
ruoyuGao Jun 20, 2023
23 changes: 14 additions & 9 deletions ding/entry/serial_entry_reward_model_offpolicy.py
@@ -18,12 +18,14 @@


def serial_pipeline_reward_model_offpolicy(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
cooptrain_reward: Optional[bool] = True,
Collaborator: cooptrain_reward -> joint_train_reward_model?

Contributor Author: changed

pretrain_reward: Optional[bool] = False,
Member: add comments for new arguments

Contributor Author: added

Collaborator: pretrain_reward -> pretrain_reward_model?

Contributor Author: changed

) -> 'Policy': # noqa
"""
Overview:
@@ -78,6 +80,8 @@ def serial_pipeline_reward_model_offpolicy(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
)
reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
if pretrain_reward:
reward_model.train()

# ==========
# Main loop
@@ -108,10 +112,11 @@
# collect data for reward_model training
reward_model.collect_data(new_data)
Member: add if cooptrain_reward

Contributor Author: added

replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# update reward_model
reward_model.train()
# update reward_model when you want to train the reward model inside the main loop
if cooptrain_reward:
reward_model.train()
# clear buffer per fix iters to make sure replay buffer's data count isn't too few.
if count % cfg.reward_model.clear_buffer_per_iters == 0:
if hasattr(cfg.reward_model, 'clear_buffer_per_iters') and count % cfg.reward_model.clear_buffer_per_iters == 0:
reward_model.clear_data()
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect):
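
The two new arguments split reward-model training into an optional pretraining phase (run once before the main loop) and an optional joint-training phase (run inside the collect/learn loop). Below is a minimal usage sketch of the refactored off-policy entry, assuming a CartPole DQN config with a reward_model section attached; the values are illustrative and not part of this PR's diff:

# Usage sketch (illustrative, not part of this PR's diff).
from copy import deepcopy

from ding.entry import serial_pipeline_reward_model_offpolicy
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import (
    cartpole_dqn_config, cartpole_dqn_create_config
)

cfg, create_cfg = deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)
# The pipeline expects a reward_model section, e.g. one of the dicts used by the
# parametrized test below (gail, red, trex, ...).
# cfg.reward_model = {...}

serial_pipeline_reward_model_offpolicy(
    (cfg, create_cfg),
    seed=0,
    max_train_iter=100,
    pretrain_reward=False,  # True: train the reward model once before the RL loop starts
    cooptrain_reward=True,  # True: keep training it jointly inside the collect/learn loop
)
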
25 changes: 15 additions & 10 deletions ding/entry/serial_entry_reward_model_onpolicy.py
@@ -18,12 +18,14 @@


def serial_pipeline_reward_model_onpolicy(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
cooptrain_reward: Optional[bool] = True,
pretrain_reward: Optional[bool] = False,
) -> 'Policy': # noqa
"""
Overview:
@@ -78,7 +80,8 @@ def serial_pipeline_reward_model_onpolicy(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
)
reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)

if pretrain_reward:
reward_model.train()
# ==========
# Main loop
# ==========
@@ -106,10 +109,12 @@
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
new_data_count += len(new_data)
# collect data for reward_model training
reward_model.collect_data(new_data)
if cooptrain_reward:
reward_model.collect_data(new_data)
# update reward_model
reward_model.train()
if count % cfg.reward_model.clear_buffer_per_iters == 0:
if cooptrain_reward:
reward_model.train()
if hasattr(cfg.reward_model, 'clear_buffer_per_iters') and count % cfg.reward_model.clear_buffer_per_iters == 0:
reward_model.clear_data()
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect):
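
Both entries now guard the periodic buffer clearing with hasattr, so clear_buffer_per_iters becomes an optional config key. A sketch of a reward_model config fragment that opts into clearing; everything except clear_buffer_per_iters is illustrative:

# Illustrative config fragment (not from this PR). If clear_buffer_per_iters is
# omitted, the pipelines above never call reward_model.clear_data() on a schedule.
from easydict import EasyDict

reward_model_cfg = EasyDict(
    type='rnd',                  # any registered reward model type (illustrative)
    clear_buffer_per_iters=100,  # clear the reward model's stored data every 100 collect iterations
)
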
62 changes: 0 additions & 62 deletions ding/entry/tests/test_serial_entry_preference_based_irl.py

This file was deleted.

60 changes: 45 additions & 15 deletions ding/entry/tests/test_serial_entry_reward_model.py
@@ -5,38 +5,54 @@
from copy import deepcopy

from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
cartpole_trex_offppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
serial_pipeline_reward_model_onpolicy
from ding.entry.application_entry_trex_collect_data import trex_collecting_data

cfg = [
{
'type': 'pdeil',
"alpha": 0.5,
"discrete_action": False
},
{
}, {
'type': 'gail',
'input_size': 5,
'hidden_size': 64,
'hidden_size_list': [64],
'batch_size': 64,
},
{
}, {
'type': 'pwil',
's_size': 4,
'a_size': 2,
'sample_size': 500,
},
{
}, {
'type': 'red',
'sample_size': 5000,
'input_size': 5,
'hidden_size': 64,
'obs_shape': 4,
'action_shape': 1,
'hidden_size_list': [64, 1],
'update_per_collect': 200,
'batch_size': 128,
},
}, {
'type': 'trex',
'exp_name': 'cartpole_trex_offppo_seed0',
'min_snippet_length': 5,
'max_snippet_length': 100,
'checkpoint_min': 0,
'checkpoint_max': 6,
'checkpoint_step': 6,
'learning_rate': 1e-5,
'update_per_collect': 1,
'expert_model_path': 'cartpole_ppo_offpolicy_seed0',
'data_path': 'abs data path',
'hidden_size_list': [512, 64, 1],
'obs_shape': 4,
'action_shape': 2,
}
]


@@ -51,9 +67,15 @@ def test_irl(reward_model_config):
expert_data_path = 'expert_data.pkl'
state_dict = expert_policy.collect_mode.state_dict()
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
collect_demo_data(
config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
)
if reward_model_config.type == 'trex':
trex_config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
trex_config[0].reward_model = reward_model_config
args = EasyDict({'cfg': deepcopy(trex_config), 'seed': 0, 'device': 'cpu'})
trex_collecting_data(args=args)
else:
collect_demo_data(
config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
)
# irl + rl training
cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
@@ -64,10 +86,18 @@
reward_model_config['expert_data_path'] = expert_data_path
cp_cartpole_dqn_config.reward_model = reward_model_config
cp_cartpole_dqn_config.policy.collect.n_sample = 128
cooptrain_reward = True
pretrain_reward = False
if reward_model_config.type == 'trex':
cooptrain_reward = False
pretrain_reward = True
serial_pipeline_reward_model_offpolicy(
(cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
(cp_cartpole_dqn_config, cp_cartpole_dqn_create_config),
seed=0,
max_train_iter=2,
pretrain_reward=pretrain_reward,
cooptrain_reward=cooptrain_reward
)

os.popen("rm -rf ckpt_* log expert_data.pkl")


2 changes: 2 additions & 0 deletions ding/reward_model/__init__.py
@@ -13,3 +13,5 @@
from .guided_cost_reward_model import GuidedCostRewardModel
from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
from .icm_reward_model import ICMRewardModel
from .network import RepresentationNetwork, RNDNetwork, REDNetwork, GAILNetwork, ICMNetwork, GCLNetwork, TREXNetwork
from .reword_model_utils import concat_state_action_pairs, combine_intrinsic_exterinsic_reward, obs_norm, collect_states