Commit 7e67df2: working

Andrew-Luo1 committed Apr 12, 2024
1 parent d63eb55
Showing 3 changed files with 96 additions and 81 deletions.
2 changes: 2 additions & 0 deletions brax/training/agents/apg/networks.py
@@ -22,6 +22,7 @@
from brax.training.types import PRNGKey
import flax
from flax import linen
from flax.linen.initializers import orthogonal


@flax.struct.dataclass
@@ -65,6 +66,7 @@ def make_apg_networks(
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=hidden_layer_sizes, activation=activation,
kernel_init = orthogonal(0.01),
layer_norm=layer_norm)
return APGNetworks(
policy_network=policy_network,
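The networks.py change above threads an `orthogonal(0.01)` kernel initializer into the APG policy network, so the freshly initialized policy produces near-zero outputs. A minimal standalone sketch of that effect (not code from this commit; the layer sizes are arbitrary assumptions):

```python
# Standalone illustration of orthogonal(0.01) initialization; the
# dimensions here are arbitrary and not taken from the commit.
import jax
import jax.numpy as jnp
from flax import linen
from flax.linen.initializers import orthogonal

obs_dim, act_dim = 8, 2  # placeholder sizes
head = linen.Dense(act_dim, kernel_init=orthogonal(0.01))
params = head.init(jax.random.PRNGKey(0), jnp.ones((1, obs_dim)))
out = head.apply(params, jax.random.normal(jax.random.PRNGKey(1), (1, obs_dim)))
# The kernel is orthogonal scaled by 0.01, so initial outputs stay close to zero.
print(out)
```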
162 changes: 86 additions & 76 deletions brax/training/agents/apg/train.py
@@ -53,10 +53,11 @@ class TrainingState:
def _unpmap(v):
return jax.tree_util.tree_map(lambda x: x[0], v)


def train(
environment: Union[envs_v1.Env, envs.Env],
timesteps: int,
episode_length: int,
policy_updates: int,
horizon_length: int = 32,
num_envs: int = 1,
num_evals: int = 1,
@@ -66,7 +67,8 @@ def train(
learning_rate: float = 1e-4,
adam_b: list = [0.7, 0.95],
use_schedule: bool = True,
schedule_decay: float = 0.995,
use_float64: bool = True,
schedule_decay: float = 0.997,
seed: int = 0,
max_gradient_norm: float = 1e9,
normalize_observations: bool = False,
@@ -95,9 +97,9 @@
process_id, local_device_count, local_devices_to_use)
device_count = local_devices_to_use * process_count

num_updates = jnp.ceil(timesteps / (num_envs * horizon_length)) # Total # of policy updates
num_updates = policy_updates
num_evals_after_init = max(num_evals - 1, 1)
updates_per_epoch = jnp.ceil(num_updates / (num_evals_after_init))
updates_per_epoch = jnp.round(num_updates / (num_evals_after_init))

assert num_envs % device_count == 0
env = environment
@@ -125,6 +127,7 @@
)

reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

normalize = lambda x, y: x
if normalize_observations:
@@ -148,8 +151,9 @@
)

def scramble_times(state, key):
state.info['steps'] = jnp.round(jax.random.uniform(key, (local_devices_to_use, num_envs,), maxval=episode_length))

state.info['steps'] = jnp.round(
jax.random.uniform(key, (local_devices_to_use, num_envs,),
maxval=episode_length))
return state

def env_step(
@@ -161,77 +165,65 @@ def env_step(
key, key_sample = jax.random.split(key)
actions = policy(env_state.obs, key_sample)[0]
nstate = env.step(env_state, actions)
# if truncation_length is not None: # TODO: Remove
# nstate = jax.lax.cond(
# jnp.mod(step_index + 1, truncation_length) == 0.,
# jax.lax.stop_gradient, lambda x: x, nstate)

return (nstate, key), (nstate.reward, env_state.obs)

def loss(env_state, policy_params, normalizer_params, key):
key_reset, key_scan = jax.random.split(key)
# env_state = env.reset(
# jax.random.split(key_reset, num_envs // process_count))
def loss(policy_params, normalizer_params, env_state, key):
f = functools.partial(
env_step, policy=make_policy((normalizer_params, policy_params)))
(state_h, _), (rewards,
obs) = jax.lax.scan(f, (env_state, key_scan),
obs) = jax.lax.scan(f, (env_state, key),
(jnp.arange(horizon_length // action_repeat)))
return -jnp.mean(rewards), obs, state_h
jax.debug.print("{x}", x=jnp.mean(rewards))

return -jnp.mean(rewards), (obs, state_h)

loss_grad = jax.grad(loss, has_aux=True)

gradient_update_fn = gradients.gradient_update_fn(
loss, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)
# loss_grad = jax.grad(loss, has_aux=True)
def clip_by_global_norm(updates):
g_norm = optax.global_norm(updates)
trigger = g_norm < max_gradient_norm
return jax.tree_util.tree_map(
lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm),
updates)

def minibatch_step(
carry, epoch_step_index: int):
(optimizer_state, normalizer_params,
params, key, state) = carry
key, key_loss = jax.random.split(key)
(_, obs, state_h), params, optimizer_state = gradient_update_fn(
state,
params,
normalizer_params,
key_loss,
optimizer_state=optimizer_state)
policy_params, key, state) = carry

key, key_grad = jax.random.split(key)
grad, (obs, state_h) = loss_grad(policy_params,
normalizer_params,
state,
key_grad)

grad = clip_by_global_norm(grad)
grad = jax.lax.pmean(grad, axis_name='i')
params_update, optimizer_state = optimizer.update(
grad, optimizer_state)
policy_params = optax.apply_updates(policy_params,
params_update)

normalizer_params = running_statistics.update(
normalizer_params, obs, pmap_axis_name=_PMAP_AXIS_NAME)

return (optimizer_state, normalizer_params, params, key, state_h), metrics

# def clip_by_global_norm(updates): # TODO: remove
# g_norm = optax.global_norm(updates)
# trigger = g_norm < max_gradient_norm
# return jax.tree_util.tree_map(
# lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm),
# updates)
metrics = {
'grad_norm': optax.global_norm(grad),
'params_norm': optax.global_norm(policy_params)
}

return (optimizer_state, normalizer_params, policy_params, key, state_h), metrics

def training_epoch(training_state: TrainingState, env_state: Union[envs.State, envs_v1.State], key: PRNGKey):
# key, key_epoch = jax.random.split(key)

(optimizer_state, normalizer_params,
policy_params, key, state_h), metrics = jax.lax.scan(
minibatch_step,
(optimizer_state, training_state.normalizer_params,
(training_state.optimizer_state, training_state.normalizer_params,
training_state.policy_params, key, env_state),
jnp.arange(updates_per_epoch))

# key, key_grad = jax.random.split(key)
# grad, obs = loss_grad(training_state.policy_params,
# training_state.normalizer_params, key_grad)
# grad = clip_by_global_norm(grad)
# grad = jax.lax.pmean(grad, axis_name='i')
# params_update, optimizer_state = optimizer.update(
# grad, training_state.optimizer_state)
# policy_params = optax.apply_updates(training_state.policy_params,
# params_update)


# metrics = {
# 'grad_norm': optax.global_norm(grad),
# 'params_norm': optax.global_norm(policy_params)
# }

return TrainingState(
optimizer_state=optimizer_state,
normalizer_params=normalizer_params,
@@ -266,11 +258,12 @@ def training_epoch_with_timing(training_state: TrainingState,
policy_params = apg_network.policy_network.init(global_key)
del global_key

dtype = 'float64' if use_float64 else 'float32'
training_state = TrainingState(
optimizer_state=optimizer.init(policy_params),
policy_params=policy_params,
normalizer_params=running_statistics.init_state(
specs.Array((env.observation_size,), jnp.dtype('float32'))))
specs.Array((env.observation_size,), jnp.dtype(dtype))))
training_state = jax.device_put_replicated(
training_state,
jax.local_devices()[:local_devices_to_use])
@@ -298,35 +291,51 @@ def training_epoch_with_timing(training_state: TrainingState,

# Run initial eval
metrics = {}
if process_id == 0 and num_evals > 1:
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.policy_params)),
training_metrics={})
logging.info(metrics)
progress_fn(0, metrics)

init_key, scramble_key, local_key = jax.random.split(local_key, 3)
keys = jax.random.split(init_key, (local_devices_to_use, num_envs))
env_state = reset_fn(keys) # First batch axes for process, 2nd for envs.
env_state = scramble_times(env_state, scramble_key)
print(env_state.info['steps'])

# if process_id == 0 and num_evals > 1:
# metrics = evaluator.run_evaluation(
# _unpmap(
# (training_state.normalizer_params, training_state.policy_params)),
# training_metrics={})
# logging.info(metrics)
# progress_fn(0, metrics)

init_key, scramble_key, local_key = jax.random.split(local_key, 3)
init_key = jax.random.split(init_key, (local_devices_to_use, num_envs // process_count))
env_state = reset_fn(init_key)
env_state = scramble_times(env_state, scramble_key)
env_state = step_fn(env_state, jnp.zeros((local_devices_to_use, num_envs // process_count,
env.action_size))) # Prevent recompilation on the second epoch

epoch_key, local_key = jax.random.split(local_key)
epoch_key = jax.random.split(epoch_key, local_devices_to_use)

for it in range(num_evals_after_init):
logging.info('starting iteration %s %s', it, time.time() - xt)

# optimization
from pathlib import Path
import pickle

file_name = f'checkpoint_{it}.pkl'
base_path = "/tmp/checkpoints/"
save_to = str(Path(
Path(base_path),
Path(file_name)))
algo_state = {'training_state': training_state}
pickle.dump(algo_state, open(save_to, "wb"))

(training_state, env_state,
training_metrics, keys) = training_epoch_with_timing(training_state, env_state, keys)
training_metrics, epoch_key) = training_epoch_with_timing(training_state, env_state, epoch_key)

if process_id == 0:
# Run evals.
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.policy_params)),
training_metrics)
logging.info(metrics)
progress_fn(it + 1, metrics)
# if process_id == 0:
# # Run evals.
# metrics = evaluator.run_evaluation(
# _unpmap(
# (training_state.normalizer_params, training_state.policy_params)),
# training_metrics)
# logging.info(metrics)
# progress_fn(it + 1, metrics)

# If there were no mistakes, the training_state should still be identical on all
# devices.
@@ -335,3 +344,4 @@ def training_epoch_with_timing(training_state: TrainingState,
(training_state.normalizer_params, training_state.policy_params))
pmap.synchronize_hosts()
return (make_policy, params, metrics)
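
For reference, a rough usage sketch of the updated `train()` signature, which now takes `policy_updates` in place of the removed `timesteps` argument. The environment choice and hyperparameter values below are illustrative assumptions, not taken from this commit:

```python
# Hypothetical invocation; environment name and hyperparameters are placeholders.
from brax import envs
from brax.training.agents.apg import train as apg

env = envs.get_environment('ant')  # assumed environment choice
make_policy, params, metrics = apg.train(
    environment=env,
    episode_length=1000,
    policy_updates=200,        # replaces the removed `timesteps` argument
    horizon_length=32,
    num_envs=64,
    num_evals=10,
    learning_rate=1e-4,
    schedule_decay=0.997,
    normalize_observations=True,
    seed=0,
)
policy_fn = make_policy(params)  # build an inference function from the returned params
```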

13 changes: 8 additions & 5 deletions brax/training/networks.py
@@ -42,7 +42,7 @@ class MLP(linen.Module):
activate_final: bool = False
bias: bool = True
layer_norm: bool = False

@linen.compact
def __call__(self, data: jnp.ndarray):
hidden = data
@@ -55,8 +55,8 @@ def __call__(self, data: jnp.ndarray):
hidden)
if i != len(self.layer_sizes) - 1 or self.activate_final:
hidden = self.activation(hidden)
if self.layer_norm:
hidden = linen.LayerNorm()(hidden)
if self.layer_norm:
hidden = linen.LayerNorm()(hidden)
return hidden


@@ -89,12 +89,15 @@ def make_policy_network(
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
activation: ActivationFn = linen.relu,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
layer_norm: bool = False) -> FeedForwardNetwork:
"""Creates a policy network."""
policy_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [param_size],
activation=activation,
kernel_init=jax.nn.initializers.lecun_uniform())
kernel_init=kernel_init,
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = preprocess_observations_fn(obs, processor_params)
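The two new `make_policy_network` arguments simply pass through to the underlying `MLP` module. A minimal sketch of the resulting module configuration (layer sizes, activation, and input width below are placeholder assumptions, not values from the commit):

```python
# Illustrative MLP configuration mirroring the arguments threaded through
# make_policy_network; all sizes here are placeholders.
import jax
import jax.numpy as jnp
from flax import linen
from flax.linen.initializers import orthogonal
from brax.training.networks import MLP

module = MLP(
    layer_sizes=[256, 256, 16],   # hidden layers plus a 16-unit policy head (assumed)
    activation=linen.swish,
    kernel_init=orthogonal(0.01),
    layer_norm=True,              # LayerNorm applied after each hidden activation
)
params = module.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
actions = module.apply(params, jnp.ones((1, 32)))
```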
