
Commit

minor gym refactoring
kzl committed Jun 11, 2021
1 parent 5fc73c1 commit 046702e
Showing 2 changed files with 28 additions and 29 deletions.
25 changes: 14 additions & 11 deletions gym/decision_transformer/models/decision_transformer.py
@@ -33,6 +33,8 @@ def __init__(
**kwargs
)

+ # note: the only difference between this GPT2Model and the default Huggingface version
+ # is that the positional embeddings are removed (since we'll add those ourselves)
self.transformer = GPT2Model(config)

self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
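The comment added above is the main point of this hunk: `GPT2Model` is used without its positional embeddings, and the learned `embed_timestep` supplies the position signal instead. A minimal, self-contained sketch of that pattern (illustrative only, with made-up dimensions; not the repo's exact `forward` code):

```python
import torch
import torch.nn as nn

hidden_size, max_ep_len, state_dim = 128, 1000, 11       # made-up sizes for illustration
embed_timestep = nn.Embedding(max_ep_len, hidden_size)   # learned per-timestep embedding
embed_state = nn.Linear(state_dim, hidden_size)          # projects raw states to hidden size

states = torch.randn(1, 20, state_dim)                   # (batch, seq_len, state_dim)
timesteps = torch.arange(20).unsqueeze(0)                # (batch, seq_len) absolute env timesteps

# each modality embedding gets the same timestep embedding added, standing in
# for the positional embedding that was removed from GPT2Model
state_tokens = embed_state(states) + embed_timestep(timesteps)
```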
@@ -87,14 +89,14 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
)
x = transformer_outputs['last_hidden_state']

- # reshape x so that the second dimension corresponds to
- # predicting returns (0), actions (1), or states (2)
+ # reshape x so that the second dimension corresponds to the original
+ # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

# get predictions
- return_preds = self.predict_return(x[:,0])
- state_preds = self.predict_state(x[:,2])
- action_preds = self.predict_action(x[:,1])
+ return_preds = self.predict_return(x[:,2])  # predict next return given state and action
+ state_preds = self.predict_state(x[:,2])  # predict next state given state and action
+ action_preds = self.predict_action(x[:,1])  # predict next action given state

return state_preds, action_preds, return_preds
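The index change above (the return and state heads now reading `x[:,2]`, the action head `x[:,1]`) follows from the (return, state, action) interleaving described in the new comment. A self-contained sketch of the reshape round-trip, assuming tokens are stacked in that order before the transformer:

```python
import torch

batch_size, seq_length, hidden_size = 2, 5, 8

# dummy embeddings, tagged by value so positions are easy to check
returns_emb = torch.full((batch_size, seq_length, hidden_size), 0.)
state_emb   = torch.full((batch_size, seq_length, hidden_size), 1.)
action_emb  = torch.full((batch_size, seq_length, hidden_size), 2.)

# interleave per timestep as (R_t, s_t, a_t) -> sequence of length 3 * seq_length
stacked = torch.stack((returns_emb, state_emb, action_emb), dim=1)          # (B, 3, T, H)
stacked = stacked.permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, hidden_size)

# after the transformer, undo the interleaving (mirrors the reshape shown above)
x = stacked.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)

assert torch.all(x[:, 0] == 0.)  # return tokens R_t
assert torch.all(x[:, 1] == 1.)  # state tokens s_t  -> fed to predict_action
assert torch.all(x[:, 2] == 2.)  # action tokens a_t -> fed to predict_state / predict_return
```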

@@ -112,26 +114,27 @@ def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwarg
returns_to_go = returns_to_go[:,-self.max_length:]
timesteps = timesteps[:,-self.max_length:]

- # padding
+ # pad all tokens to sequence length
attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
states = torch.cat(
[torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
dim=1).to(dtype=torch.float32)
+ actions = torch.cat(
+ [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim),
+ device=actions.device), actions],
+ dim=1).to(dtype=torch.float32)
returns_to_go = torch.cat(
[torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
dim=1).to(dtype=torch.float32)
timesteps = torch.cat(
[torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
dim=1
).to(dtype=torch.long)

- actions = torch.cat(
- [torch.zeros((actions.shape[0], self.max_length-actions.shape[1], self.act_dim), device=actions.device), actions],
- dim=1).to(dtype=torch.float32)
else:
attention_mask = None

- _, action_preds, return_preds = self.forward(states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)
+ _, action_preds, return_preds = self.forward(
+ states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)

return action_preds[0,-1]
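The block above left-pads every tensor to `self.max_length` and records which positions are real in `attention_mask`; the changes here relocate the `actions` padding, reword the comment, and wrap the `forward` call. A stripped-down sketch of the padding pattern with hypothetical shapes (not the repo code):

```python
import torch

max_length, state_dim = 20, 11
states = torch.randn(1, 7, state_dim)    # only 7 timesteps observed so far

pad = max_length - states.shape[1]
# 0 = padding (ignored by attention), 1 = real timestep
attention_mask = torch.cat([torch.zeros(pad), torch.ones(states.shape[1])])
attention_mask = attention_mask.to(dtype=torch.long).reshape(1, -1)

# left-pad with zeros so the most recent timestep stays in the last position
states_padded = torch.cat([torch.zeros((1, pad, state_dim)), states], dim=1).to(dtype=torch.float32)

assert states_padded.shape == (1, max_length, state_dim)
assert int(attention_mask.sum()) == 7
```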
32 changes: 14 additions & 18 deletions gym/experiment.py
@@ -38,8 +38,8 @@ def experiment(
if env_name == 'hopper':
env = gym.make('Hopper-v3')
max_ep_len = 1000
- env_targets = [3600, 1800]
- scale = 1000.
+ env_targets = [3600, 1800]  # evaluation conditioning targets
+ scale = 1000.  # normalization for rewards/returns
elif env_name == 'halfcheetah':
env = gym.make('HalfCheetah-v3')
max_ep_len = 1000
@@ -65,21 +65,16 @@
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

- small = variant.get('small', False)
- dataset_path = f'{env_name}-{dataset}-v2.pkl'
- if small and env_name == 'reacher2d':
- dataset_path = f'smallest1t-{dataset_path}'
- elif small:
- dataset_path = f'smallest-{dataset_path}'
- dataset_path = f'data/{dataset_path}'

+ # load dataset
+ dataset_path = f'data/{env_name}-{dataset}-v2.pkl'
with open(dataset_path, 'rb') as f:
trajectories = pickle.load(f)

# save all path information into separate lists
mode = variant.get('mode', 'normal')
states, traj_lens, returns = [], [], []
for path in trajectories:
- if mode == 'delayed':
+ if mode == 'delayed':  # delayed: all rewards moved to end of trajectory
path['rewards'][-1] = path['rewards'].sum()
path['rewards'][:-1] = 0.
states.append(path['observations'])
@@ -108,7 +103,6 @@ def experiment(
# only train on top pct_traj trajectories (for %BC experiment)
num_timesteps = max(int(pct_traj*num_timesteps), 1)
sorted_inds = np.argsort(returns) # lowest to highest

num_trajectories = 1
timesteps = traj_lens[sorted_inds[-1]]
ind = len(trajectories) - 2
@@ -117,6 +111,8 @@
num_trajectories += 1
ind -= 1
sorted_inds = sorted_inds[-num_trajectories:]

+ # used to reweight sampling so we sample according to timesteps instead of trajectories
p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

def get_batch(batch_size=256, max_len=K):
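The comment added above explains `p_sample`: trajectories are drawn in proportion to their length, so each timestep in the dataset is roughly equally likely to end up in a batch. A tiny sketch of that weighting (the `np.random.choice` call is an assumption about what `get_batch` does below; it is not part of this diff):

```python
import numpy as np

traj_lens = np.array([1000, 500, 250])   # lengths of the retained trajectories
p_sample = traj_lens / traj_lens.sum()   # [0.571..., 0.286..., 0.143...]

# draw trajectory indices with these weights; a start index is then sampled uniformly
# inside each drawn trajectory, so long trajectories contribute proportionally more
batch_inds = np.random.choice(np.arange(len(traj_lens)), size=256, replace=True, p=p_sample)
```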
@@ -132,6 +128,7 @@ def get_batch(batch_size=256, max_len=K):
traj = trajectories[int(sorted_inds[batch_inds[i]])]
si = random.randint(0, traj['rewards'].shape[0] - 1)

+ # get sequences from dataset
s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
@@ -145,21 +142,22 @@
if rtg[-1].shape[1] <= s[-1].shape[1]:
rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

+ # padding and state + reward normalization
tlen = s[-1].shape[1]
s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)
s[-1] = (s[-1] - state_mean) / state_std
a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1)
r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)
d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)
- rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1)
+ rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale
timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))

s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
- rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device) / scale
+ rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
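The two `rtg` edits above move the division by `scale` from the final tensor conversion into the per-sample padding step; the values the model sees are effectively unchanged. A small worked sketch of scaled returns-to-go (illustrative reward values; the repo computes `rtg` elsewhere in `get_batch`):

```python
import numpy as np

scale = 1000.                          # Hopper normalization from above
rewards = np.array([1.0, 2.0, 3.0, 4.0])

# undiscounted return-to-go at each step: this reward plus everything after it
rtg = np.cumsum(rewards[::-1])[::-1]   # [10., 9., 7., 4.]
rtg_scaled = rtg / scale               # [0.010, 0.009, 0.007, 0.004]

# the same scaling applies to evaluation targets, e.g. a 3600 return target becomes 3.6
```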

@@ -207,7 +205,6 @@ def fn(model):
}
return fn


if model_type == 'dt':
model = DecisionTransformer(
state_dim=state_dim,
@@ -277,7 +274,6 @@ def fn(model):
)
# wandb.watch(model) # wandb has some bug

- max_return = -1e6
for iter in range(variant['max_iters']):
outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter+1, print_logs=True)
if log_to_wandb:
@@ -287,8 +283,8 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='hopper')
- parser.add_argument('--dataset', type=str, default='medium')
- parser.add_argument('--mode', type=str, default='normal')
+ parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
+ parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
parser.add_argument('--K', type=int, default=20)
parser.add_argument('--pct_traj', type=float, default=1.)
parser.add_argument('--batch_size', type=int, default=64)
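For context, a hedged sketch of how these flags presumably reach `experiment()` at the bottom of the file; the closing lines are outside this diff, so the exact call is an assumption:

```python
# assumed tail of experiment.py (not shown in this diff)
args = parser.parse_args()
experiment('gym-experiment', variant=vars(args))
```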
