diff --git a/brax/training/agents/sac/networks.py b/brax/training/agents/sac/networks.py
index 7e8d58e9..dc50106a 100644
--- a/brax/training/agents/sac/networks.py
+++ b/brax/training/agents/sac/networks.py
@@ -56,7 +56,9 @@ def make_sac_networks(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: networks.ActivationFn = linen.relu) -> SACNetworks:
+    activation: networks.ActivationFn = linen.relu,
+    policy_network_layer_norm: bool = False,
+    q_network_layer_norm: bool = False) -> SACNetworks:
   """Make SAC networks."""
   parametric_action_distribution = distribution.NormalTanhDistribution(
       event_size=action_size)
@@ -65,13 +67,15 @@ def make_sac_networks(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=policy_network_layer_norm)
   q_network = networks.make_q_network(
       observation_size,
       action_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=q_network_layer_norm)
   return SACNetworks(
       policy_network=policy_network,
       q_network=q_network,
diff --git a/brax/training/networks.py b/brax/training/networks.py
index 76e9c12f..23e041aa 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -136,7 +136,8 @@ def make_q_network(
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
     activation: ActivationFn = linen.relu,
-    n_critics: int = 2) -> FeedForwardNetwork:
+    n_critics: int = 2,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a value network."""
 
   class QModule(linen.Module):
@@ -151,7 +152,8 @@ def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray):
       q = MLP(
           layer_sizes=list(hidden_layer_sizes) + [1],
           activation=activation,
-          kernel_init=jax.nn.initializers.lecun_uniform())(
+          kernel_init=jax.nn.initializers.lecun_uniform(),
+          layer_norm=layer_norm)(
               hidden)
       res.append(q)
     return jnp.concatenate(res, axis=-1)
diff --git a/brax/v1/jumpy.py b/brax/v1/jumpy.py
index 168d9c10..2c65ef23 100644
--- a/brax/v1/jumpy.py
+++ b/brax/v1/jumpy.py
@@ -43,7 +43,7 @@ def _in_jit() -> bool:
 
   if jax.__version_info__ <= (0, 4, 33):
     return 'DynamicJaxprTrace' in str(
-        jax.core.thread_local_state.trace_state.trace_stack
+        jax.core.thread_local_state.trace_state.trace_stack  # type: ignore
     )
 
   return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
diff --git a/docs/release-notes/next-release.md b/docs/release-notes/next-release.md
index 7e5c6680..46b3124c 100644
--- a/docs/release-notes/next-release.md
+++ b/docs/release-notes/next-release.md
@@ -1,4 +1,5 @@
 # Brax Release Notes
 
 * Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is.
-* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.
\ No newline at end of file
+* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.
+* Add `layer_norm` to `make_q_network`, and expose `policy_network_layer_norm` and `q_network_layer_norm` flags (default `False`) on `make_sac_networks`.
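
Usage sketch (not part of the patch): a minimal example of how the new flags could be exercised once this lands. The observation and action sizes below are arbitrary placeholders.

    # Build SAC networks with LayerNorm enabled only in the Q-network MLPs.
    # observation_size/action_size are placeholder values for illustration.
    from brax.training.agents.sac import networks as sac_networks

    sac_nets = sac_networks.make_sac_networks(
        observation_size=8,
        action_size=2,
        policy_network_layer_norm=False,  # actor MLP unchanged (default)
        q_network_layer_norm=True,        # forwarded as layer_norm to make_q_network
    )

Keeping the two flags separate allows LayerNorm on the critic without touching the actor, which mirrors how the options are plumbed through in the hunks above.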