-
Notifications
You must be signed in to change notification settings - Fork 378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor(gry): refactor reward model #636
base: main
Are you sure you want to change the base?
Changes from 25 commits
c372c07
6718e4a
be7039a
a4de466
d615c14
7a8ec6e
55c7be8
ff60716
6b80392
25d49b5
c081ff0
f1218cd
d9060c2
179182a
d067731
29f0d55
4ec0bd3
800f090
c64b5c7
660af32
1b0d579
b4e81dd
eddc80d
6e2b867
e25d265
97634dc
f099cac
0c48c08
594d619
58a2bff
e9db652
822d7a4
be03aa9
d3ce3e2
4c19aa3
0cc2149
5b4e4cc
ff4de47
9e63ef1
ca2e2db
8716afe
9036141
6b9754a
a52a1c0
a5c7989
d631237
f42d131
edff260
6ab66e1
cb0c627
016fbb3
cf50148
919c01b
a1d0b3a
0a0af3c
e310b4c
92dc227
a4f364d
1f06dec
a547b3b
97da5c6
774b2a4
b78e36c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,12 +18,14 @@ | |
|
||
|
||
def serial_pipeline_reward_model_offpolicy( | ||
input_cfg: Union[str, Tuple[dict, dict]], | ||
seed: int = 0, | ||
env_setting: Optional[List[Any]] = None, | ||
model: Optional[torch.nn.Module] = None, | ||
max_train_iter: Optional[int] = int(1e10), | ||
max_env_step: Optional[int] = int(1e10), | ||
input_cfg: Union[str, Tuple[dict, dict]], | ||
seed: int = 0, | ||
env_setting: Optional[List[Any]] = None, | ||
model: Optional[torch.nn.Module] = None, | ||
max_train_iter: Optional[int] = int(1e10), | ||
max_env_step: Optional[int] = int(1e10), | ||
cooptrain_reward: Optional[bool] = True, | ||
pretrain_reward: Optional[bool] = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add comments for new arguments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed |
||
) -> 'Policy': # noqa | ||
""" | ||
Overview: | ||
|
@@ -78,6 +80,8 @@ def serial_pipeline_reward_model_offpolicy( | |
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode | ||
) | ||
reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger) | ||
if pretrain_reward: | ||
reward_model.train() | ||
|
||
# ========== | ||
# Main loop | ||
|
@@ -108,10 +112,11 @@ def serial_pipeline_reward_model_offpolicy( | |
# collect data for reward_model training | ||
reward_model.collect_data(new_data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add an `if cooptrain_reward` guard There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) | ||
# update reward_model | ||
reward_model.train() | ||
# update reward_model, when you want to train reward_model inloop | ||
if cooptrain_reward: | ||
reward_model.train() | ||
# clear buffer per fix iters to make sure replay buffer's data count isn't too few. | ||
if count % cfg.reward_model.clear_buffer_per_iters == 0: | ||
if hasattr(cfg.reward_model, 'clear_buffer_per_iters') and count % cfg.reward_model.clear_buffer_per_iters == 0: | ||
reward_model.clear_data() | ||
# Learn policy from collected data | ||
for i in range(cfg.policy.learn.update_per_collect): | ||
|
This file was deleted.
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: rename `cooptrain_reward` to `joint_train_reward_model`?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed