diff --git a/init2winit/model_lib/dlrm.py b/init2winit/model_lib/dlrm.py
index f12cf037..6313dbb7 100644
--- a/init2winit/model_lib/dlrm.py
+++ b/init2winit/model_lib/dlrm.py
@@ -132,7 +132,7 @@ def __call__(self, x, train):
       bot_mlp_input = nn.Dense(
           dense_dim,
           kernel_init=jnn.initializers.glorot_uniform(),
-          bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+          bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
       )(bot_mlp_input)
       bot_mlp_input = activation_fn(bot_mlp_input)
       bot_mlp_input = normalizer_layer()(bot_mlp_input)
@@ -143,7 +143,7 @@ def __call__(self, x, train):
 
     base_init_fn = jnn.initializers.uniform(scale=1.0)
     if self.embedding_init_multiplier is None:
-      embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size)
+      embedding_init_multiplier = 1 / self.vocab_size**0.5
     else:
       embedding_init_multiplier = self.embedding_init_multiplier
     # Embedding table init and lookup for a single unified table.
@@ -173,9 +173,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
       top_mlp_input = nn.Dense(
           fan_out,
           kernel_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
+              stddev=(2.0 / (fan_in + fan_out))**0.5),
          bias_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
+              stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))(
                  top_mlp_input)
       if layer_idx < (num_layers_top - 1):
         top_mlp_input = activation_fn(top_mlp_input)
@@ -251,7 +251,7 @@ def __call__(self, x, train):
         mlp_bottom_dims[0],
         kernel_init=jnn.initializers.glorot_uniform(),
         bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])),
+            stddev=1.0 / mlp_bottom_dims[0]**0.5),
     )(bot_mlp_input)
     bot_mlp_input = activation_fn(bot_mlp_input)
@@ -259,13 +259,13 @@ def __call__(self, x, train):
       x = nn.Dense(
           dense_dim,
           kernel_init=jnn.initializers.glorot_uniform(),
-          bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+          bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
       )(bot_mlp_input)
       bot_mlp_input += activation_fn(x)
 
     base_init_fn = jnn.initializers.uniform(scale=1.0)
     if self.embedding_init_multiplier is None:
-      embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size)
+      embedding_init_multiplier = 1 / self.vocab_size**0.5
     else:
       embedding_init_multiplier = self.embedding_init_multiplier
     # Embedding table init and lookup for a single unified table.
@@ -289,9 +289,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
     top_mlp_input = nn.Dense(
         mlp_top_dims[0],
         kernel_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
+            stddev=(2.0 / (mlp_input_dim + mlp_top_dims[0]))**0.5),
         bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
+            stddev=(1.0 / mlp_top_dims[0])**0.5))(
                 top_mlp_input)
     top_mlp_input = activation_fn(top_mlp_input)
     for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
@@ -299,9 +299,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
       x = nn.Dense(
           fan_out,
           kernel_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
+              stddev=(2.0 / (fan_in + fan_out))**0.5),
          bias_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
+              stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))(
                  top_mlp_input)
       x = activation_fn(x)
       if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2:
@@ -313,9 +313,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
     logits = nn.Dense(
         1,
         kernel_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
+            stddev=(2.0 / (mlp_top_dims[-2] + 1))**0.5),
         bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0)))(top_mlp_input)
+            stddev=1.0))(top_mlp_input)
     return logits