diff --git a/gym/decision_transformer/models/decision_transformer.py b/gym/decision_transformer/models/decision_transformer.py
index f10b4cb0..d4985a56 100644
--- a/gym/decision_transformer/models/decision_transformer.py
+++ b/gym/decision_transformer/models/decision_transformer.py
@@ -33,6 +33,8 @@ def __init__(
             **kwargs
         )
 
+        # note: the only difference between this GPT2Model and the default Huggingface version
+        # is that the positional embeddings are removed (since we'll add those ourselves)
         self.transformer = GPT2Model(config)
 
         self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
@@ -87,14 +89,14 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
         )
         x = transformer_outputs['last_hidden_state']
 
-        # reshape x so that the second dimension corresponds to
-        # predicting returns (0), actions (1), or states (2)
+        # reshape x so that the second dimension corresponds to the original
+        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
         x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
 
         # get predictions
-        return_preds = self.predict_return(x[:,0])
-        state_preds = self.predict_state(x[:,2])
-        action_preds = self.predict_action(x[:,1])
+        return_preds = self.predict_return(x[:,2])  # predict next return given state and action
+        state_preds = self.predict_state(x[:,2])    # predict next state given state and action
+        action_preds = self.predict_action(x[:,1])  # predict next action given state
 
         return state_preds, action_preds, return_preds
 
@@ -112,12 +114,16 @@ def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwarg
             returns_to_go = returns_to_go[:,-self.max_length:]
             timesteps = timesteps[:,-self.max_length:]
 
-            # padding
+            # pad all tokens to sequence length
             attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
             attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
             states = torch.cat(
                 [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
                 dim=1).to(dtype=torch.float32)
+            actions = torch.cat(
+                [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim),
+                             device=actions.device), actions],
+                dim=1).to(dtype=torch.float32)
             returns_to_go = torch.cat(
                 [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
                 dim=1).to(dtype=torch.float32)
@@ -125,13 +131,10 @@ def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwarg
                 [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
                 dim=1
             ).to(dtype=torch.long)
-
-            actions = torch.cat(
-                [torch.zeros((actions.shape[0], self.max_length-actions.shape[1], self.act_dim), device=actions.device), actions],
-                dim=1).to(dtype=torch.float32)
         else:
             attention_mask = None
 
-        _, action_preds, return_preds = self.forward(states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)
+        _, action_preds, return_preds = self.forward(
+            states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)
 
         return action_preds[0,-1]
diff --git a/gym/experiment.py b/gym/experiment.py
index 60416fa3..ecd86b85 100644
--- a/gym/experiment.py
+++ b/gym/experiment.py
@@ -38,8 +38,8 @@ def experiment(
     if env_name == 'hopper':
         env = gym.make('Hopper-v3')
         max_ep_len = 1000
-        env_targets = [3600, 1800]
-        scale = 1000.
+        env_targets = [3600, 1800]  # evaluation conditioning targets
+        scale = 1000.  # normalization for rewards/returns
     elif env_name == 'halfcheetah':
         env = gym.make('HalfCheetah-v3')
         max_ep_len = 1000
@@ -65,21 +65,16 @@ def experiment(
     state_dim = env.observation_space.shape[0]
     act_dim = env.action_space.shape[0]
 
-    small = variant.get('small', False)
-    dataset_path = f'{env_name}-{dataset}-v2.pkl'
-    if small and env_name == 'reacher2d':
-        dataset_path = f'smallest1t-{dataset_path}'
-    elif small:
-        dataset_path = f'smallest-{dataset_path}'
-    dataset_path = f'data/{dataset_path}'
-
+    # load dataset
+    dataset_path = f'data/{env_name}-{dataset}-v2.pkl'
     with open(dataset_path, 'rb') as f:
         trajectories = pickle.load(f)
 
+    # save all path information into separate lists
     mode = variant.get('mode', 'normal')
     states, traj_lens, returns = [], [], []
     for path in trajectories:
-        if mode == 'delayed':
+        if mode == 'delayed':  # delayed: all rewards moved to end of trajectory
             path['rewards'][-1] = path['rewards'].sum()
             path['rewards'][:-1] = 0.
         states.append(path['observations'])
@@ -108,7 +103,6 @@ def experiment(
     # only train on top pct_traj trajectories (for %BC experiment)
     num_timesteps = max(int(pct_traj*num_timesteps), 1)
     sorted_inds = np.argsort(returns)  # lowest to highest
-    num_trajectories = 1
     timesteps = traj_lens[sorted_inds[-1]]
     ind = len(trajectories) - 2
     while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
@@ -117,6 +111,8 @@
         num_trajectories += 1
         ind -= 1
     sorted_inds = sorted_inds[-num_trajectories:]
+
+    # used to reweight sampling so we sample according to timesteps instead of trajectories
     p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])
 
     def get_batch(batch_size=256, max_len=K):
@@ -132,6 +128,7 @@ def get_batch(batch_size=256, max_len=K):
             traj = trajectories[int(sorted_inds[batch_inds[i]])]
             si = random.randint(0, traj['rewards'].shape[0] - 1)
 
+            # get sequences from dataset
             s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
             a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
             r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
@@ -145,13 +142,14 @@ def get_batch(batch_size=256, max_len=K):
             if rtg[-1].shape[1] <= s[-1].shape[1]:
                 rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)
 
+            # padding and state + reward normalization
             tlen = s[-1].shape[1]
             s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)
             s[-1] = (s[-1] - state_mean) / state_std
             a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1)
             r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)
             d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)
-            rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1)
+            rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale
             timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
             mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))
 
@@ -159,7 +157,7 @@
         a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
         r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
         d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
-        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device) / scale
+        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
         timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
         mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
 
@@ -207,7 +205,6 @@ def fn(model):
             }
         return fn
 
-
     if model_type == 'dt':
         model = DecisionTransformer(
             state_dim=state_dim,
@@ -277,7 +274,6 @@ def fn(model):
     )
     # wandb.watch(model)  # wandb has some bug
 
-    max_return = -1e6
     for iter in range(variant['max_iters']):
         outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter+1, print_logs=True)
         if log_to_wandb:
@@ -287,8 +283,8 @@
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--env', type=str, default='hopper')
-    parser.add_argument('--dataset', type=str, default='medium')
-    parser.add_argument('--mode', type=str, default='normal')
+    parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
+    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--pct_traj', type=float, default=1.)
    parser.add_argument('--batch_size', type=int, default=64)
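
Note on the prediction-head change in forward() above: the model interleaves tokens as (R_1, s_1, a_1, ..., R_T, s_T, a_T), so after x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) the second dimension indexes the modality: 0 holds return tokens, 1 state tokens, 2 action tokens. That is why predict_action reads x[:,1] (next action predicted from the state token) while predict_state and predict_return read x[:,2] (predicted from the action token). The standalone sketch below illustrates that ordering; it is not part of the diff, assumes the (return, state, action) token order described by the updated comment, and uses constant dummy "embeddings" purely as tags.

import torch

batch_size, seq_length, hidden_size = 2, 4, 8

# dummy per-modality embeddings, tagged by a constant so we can track where they land
returns_emb = 0 * torch.ones(batch_size, seq_length, hidden_size)
state_emb   = 1 * torch.ones(batch_size, seq_length, hidden_size)
action_emb  = 2 * torch.ones(batch_size, seq_length, hidden_size)

# interleave to (R_1, s_1, a_1, ..., R_T, s_T, a_T)
stacked_inputs = torch.stack(
    (returns_emb, state_emb, action_emb), dim=1
).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, hidden_size)

# pretend the transformer is the identity, then undo the interleaving as in the patch
x = stacked_inputs.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)

assert torch.all(x[:, 0] == 0)  # x[:,0,t] is the token for R_t
assert torch.all(x[:, 1] == 1)  # x[:,1,t] is the token for s_t -> input to predict_action
assert torch.all(x[:, 2] == 2)  # x[:,2,t] is the token for a_t -> input to predict_state / predict_return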