seed: 42
output_dir: ${hydra:runtime.output_dir}
domains: mujoco_metaworld # domains to train on
wb_tag: "default" # wandb tag
log_interval: 10 # how many steps before logging to wandb
script_name: "" # log the running script
pretrained_dir: "" # pretrained model directory
parallel_eval: False # use ray to do parallel evaluation
slurm_job_id: "" # the slurm job id for logging purposes
user_id: "" # the machine user id for logging purposes
epoch_size: 10
total_num_traj: 0
train_time: 0
cross_validate: False
cross_validate_eps_num: 20
cross_validate_freq: 100
save_intermedia_models: False
comment: ""
defaults:
- _self_
- env: mujoco_metaworld
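# a minimal, hedged sketch of overriding the composed defaults from the command line;
# the entry-point script name (train.py) is an assumption, not taken from this file:
#   python train.py env=mujoco_metaworld domains=mujoco_metaworld dataloader.batch_size=128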
# dataset config
dataset:
_target_: hpt.dataset.local_traj_dataset.LocalTrajDataset
horizon: 1 # horizon for each dataset sample. not used
  val_ratio: 0.1 # fraction of data held out for validation
pad_after: 0 # padding after the episode
episode_cnt: 10000 # total episodes by default
step_cnt: 100000 # total data transitions
data_augmentation: False # data augmentation
use_disk: False # use disk instead of memory to store the data
pad_before: 0 # padding before the episode
data_ratio: 1 # only use a fraction of data
  action_horizon: 8 # number of action steps per sample; the full sample horizon is observation + action steps
  observation_horizon: 1 # number of observation steps at the start of each sample
dataset_postfix: "_traj${dataset.episode_cnt}" # postfix for the dataset
precompute_feat: True # precompute features using pretrained models for stems
image_encoder: 'resnet' # which encoder to use as the pretrained model
  dataset_encoder_postfix: "_${dataset.image_encoder}" # postfix for the dataset name based on the image encoder
use_multiview: False # use multiple camera views
normalize_state: True # whether to normalize the states in datasets
regenerate: False # regenerate data
action_multiple_horizon: True # multiply action dimensions by horizon
random_mask_obs: True # whether to randomize observation input length
data_augment_ratio: 1 # add data augmentation to the images
proprioception_expand: False # expand proprioception to use multiple tokens
proprioception_expand_dim: 32 # expand proprioception dimensions
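# reading of the horizon fields above (an interpretation of the comments, not verified
# against the dataset code): with observation_horizon: 1 and action_horizon: 8, each
# sample would hold one observation step followed by eight action steps.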
# trunk transformer config
network:
_target_: hpt.models.policy.Policy
embed_dim: 128
num_blocks: 16 # num of blocks in the trunk transformer
num_heads: 8 # num of heads in the trunk transformer
drop_path: 0.1 # drop path in the trunk transformer
use_modality_embedding: True # add trainable modality position tokens
use_domain_embedding: False # whether to add domain-specific trainable parameters
observation_horizon: ${dataset.horizon} # the observation history
  action_horizon: 1 # open-loop action steps; must be <= the dataset action horizon
  token_postprocessing: "mean" # "max" or "mean" pooling of the tokens
cross_stem_attention: True # use cross attention to combine state and action
weight_init_style: 'pytorch' # weight init
no_trunk: False # ignore trunk
finetune_encoder: False # whether to finetune encoders
# stem network for different modalities
stem:
modalities: ['image', 'state'] # 'language'
modality_embed_dim: ${network.embed_dim}
normalize_state: ${dataset.normalize_state} # normalize state vectors
state_embedding_dim: 1 # dimension of positional encoding for state
cross_attention: True # whether to use cross attention or not
precompute_feat: True # whether to use precomputed features. if not, will finetune.
image_encoder: ${dataset.image_encoder} # what image encoder to use
crossattn_dim_head: 64 # for cross attention modules
crossattn_heads: 8 # number of heads in cross attention
crossattn_modality_dropout: 0.1 # the dropout ratio for cross attention
num_blocks: 1 # number of blocks for stem transformer's cross and self attention
observation_horizon: ${dataset.observation_horizon} # observation horizon
masked_autoencoding: False # random mask encoding and then reconstruction
random_horizon_masking: True # whether to randomize observation input length
add_pos_embedding_to_state: False # positional embedding for the state
# standardize token sizes for each modality
crossattn_latent:
image: 16
state: 16
# language: 8
image:
_target_: hpt.models.policy_stem.MLP
input_dim: 512
output_dim: ${network.embed_dim}
widths: [128]
num_of_copy: 1
# each item is a token
state:
_target_: hpt.models.policy_stem.MLP
    input_dim: ${stem.state_embedding_dim} # overwritten based on the dataset
output_dim: ${network.embed_dim}
widths: [128]
# head network
head:
_target_: hpt.models.policy_head.MLP
input_dim: ${network.embed_dim}
tanh_end: True # normalized action output
  output_dim: -1 # overwritten based on the dataset
widths: [256, 128]
normalize_action: ${head.tanh_end}
dropout: True
# self-explanatory torch config
dataloader:
batch_size: 256
num_workers: 1
pin_memory: True
persistent_workers: True
shuffle: True
drop_last: False
val_dataloader:
batch_size: 256
num_workers: 1
shuffle: False
pin_memory: True
persistent_workers: True
drop_last: False
optimizer:
_target_: torch.optim.AdamW
lr: 1.0e-5 # 1e-4
eps: 1.0e-7
weight_decay: 1e-4
optimizer_misc:
nontrunk_lr_scale: 1.
warmup_lr:
lr: 1e-10
step: 1000 # first 1000 iterations
lr_scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: ${train.total_epochs}
eta_min: 1e-6
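# the _target_ entries above are presumably built with Hydra's instantiate utility;
# a hedged sketch, not the repo's actual training code ("model" is a placeholder):
#   optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
#   scheduler = hydra.utils.instantiate(cfg.lr_scheduler, optimizer)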
# training config
train:
  total_epochs: 500 # maximum training epochs before termination; usually set as an upper bound
total_iters: 20000 # maximum training steps before termination
epoch_iters: 1000 # training steps in each epoch
validation_iters: 100 # maximum iterations for validation
pretrained_dir: "" # pretrained model path for testing
freeze_trunk: True # whether to freeze the trunk during finetuning
wandb_pretrained_dir: "" # use models pretrained on wandb
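# a minimal sketch of loading this config programmatically with Hydra's compose API;
# the config directory name ("experiments") is an assumption about the repo layout:
#   from hydra import compose, initialize
#   with initialize(config_path="experiments", version_base=None):
#       cfg = compose(config_name="config")
#   print(cfg.train.total_epochs)  # 500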