Commit

Internal change
PiperOrigin-RevId: 691493168
Change-Id: I41833f49a9ea1d8f1ba0ac7886663d8c6b847572
Brax Team authored and btaba committed Oct 30, 2024
1 parent c87dcfc commit bf616ce
Showing 4 changed files with 14 additions and 7 deletions.
10 changes: 7 additions & 3 deletions brax/training/agents/sac/networks.py
@@ -56,7 +56,9 @@ def make_sac_networks(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: networks.ActivationFn = linen.relu) -> SACNetworks:
+    activation: networks.ActivationFn = linen.relu,
+    policy_network_layer_norm: bool = False,
+    q_network_layer_norm: bool = False) -> SACNetworks:
   """Make SAC networks."""
   parametric_action_distribution = distribution.NormalTanhDistribution(
       event_size=action_size)
@@ -65,13 +67,15 @@ def make_sac_networks(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=policy_network_layer_norm)
   q_network = networks.make_q_network(
       observation_size,
       action_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=q_network_layer_norm)
   return SACNetworks(
       policy_network=policy_network,
       q_network=q_network,
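The two new keyword arguments control layer normalization in the policy and Q networks independently, and both default to `False`. A minimal usage sketch (the observation and action sizes below are invented for illustration):

```python
from brax.training.agents.sac import networks as sac_networks

# Hypothetical sizes; in practice these come from the environment.
sac_nets = sac_networks.make_sac_networks(
    observation_size=27,
    action_size=8,
    hidden_layer_sizes=(256, 256),
    policy_network_layer_norm=False,  # layer norm in the policy MLP
    q_network_layer_norm=True,        # layer norm in the critic MLPs
)
policy_net = sac_nets.policy_network  # used by the actor
q_net = sac_nets.q_network            # used by the critics
```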
6 changes: 4 additions & 2 deletions brax/training/networks.py
@@ -136,7 +136,8 @@ def make_q_network(
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
     activation: ActivationFn = linen.relu,
-    n_critics: int = 2) -> FeedForwardNetwork:
+    n_critics: int = 2,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a value network."""

   class QModule(linen.Module):
@@ -151,7 +152,8 @@ def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray):
         q = MLP(
             layer_sizes=list(hidden_layer_sizes) + [1],
             activation=activation,
-            kernel_init=jax.nn.initializers.lecun_uniform())(
+            kernel_init=jax.nn.initializers.lecun_uniform(),
+            layer_norm=layer_norm)(
                 hidden)
         res.append(q)
       return jnp.concatenate(res, axis=-1)
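For reference, a simplified sketch of what a `layer_norm` flag usually toggles inside such an MLP: a `linen.LayerNorm` applied to each hidden layer. This is illustrative only; the exact placement relative to the activation in brax's own `MLP` module may differ.

```python
from typing import Sequence

import flax.linen as linen
import jax
import jax.numpy as jnp


class LayerNormMLP(linen.Module):
  """Illustrative MLP with optional layer norm on hidden layers."""

  layer_sizes: Sequence[int]
  layer_norm: bool = False

  @linen.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    for i, size in enumerate(self.layer_sizes):
      x = linen.Dense(
          size, kernel_init=jax.nn.initializers.lecun_uniform())(x)
      if i < len(self.layer_sizes) - 1:  # hidden layers only
        x = linen.relu(x)
        if self.layer_norm:
          x = linen.LayerNorm()(x)
    return x


# Shape of the Q head built above: two hidden layers plus a scalar output.
q_head = LayerNormMLP(layer_sizes=(256, 256, 1), layer_norm=True)
```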
2 changes: 1 addition & 1 deletion brax/v1/jumpy.py
@@ -43,7 +43,7 @@ def _in_jit() -> bool:
 
   if jax.__version_info__ <= (0, 4, 33):
     return 'DynamicJaxprTrace' in str(
-        jax.core.thread_local_state.trace_state.trace_stack
+        jax.core.thread_local_state.trace_state.trace_stack  # type: ignore
     )
 
   return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
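The added `# type: ignore` suppresses a static-typing error on newer JAX releases, where `jax.core.thread_local_state` no longer exists and only the older code path touches it. A small, hedged sketch of how such a version-gated check behaves in use (mirroring the function above, not an additional brax API):

```python
import jax
import jax.numpy as jnp


def in_jit() -> bool:
  """Version-gated check for running inside a jax.jit trace (sketch)."""
  if jax.__version_info__ <= (0, 4, 33):
    # Older JAX: inspect the tracer stack on thread-local state.
    return 'DynamicJaxprTrace' in str(
        jax.core.thread_local_state.trace_state.trace_stack)  # type: ignore
  # Newer JAX: use the explicit (if loudly named) helper.
  return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()


@jax.jit
def traced(x: jnp.ndarray) -> jnp.ndarray:
  # Runs at trace time, where the check is expected to return True.
  assert in_jit()
  return x + 1.0
```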
3 changes: 2 additions & 1 deletion docs/release-notes/next-release.md
@@ -1,4 +1,5 @@
 # Brax Release Notes
 
 * Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is.
-* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.
+* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.
+* Add `layer_norm` to `make_q_network` and set `layer_norm` to `True` in `make_sac_networks` Q Network.
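A hedged sketch of exercising the release-note items from a SAC training call; `wrap_env`, `num_timesteps`, and the layer-norm keywords come from this change, while the `network_factory` wiring and the remaining argument names are assumptions about the usual brax `train` entry point:

```python
import functools

from brax import envs
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac_train

env = envs.get_environment('halfcheetah')  # any registered brax environment

# Enable layer norm in the SAC Q network via the new keyword argument.
network_factory = functools.partial(
    sac_networks.make_sac_networks, q_network_layer_norm=True)

make_policy, params, metrics = sac_train.train(
    environment=env,
    num_timesteps=1_000_000,
    network_factory=network_factory,
    wrap_env=True,  # wrap the env for training; False uses the env as is
)
```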
