

Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns.

Integrated improved argument splitting by @cgarciae in google/flax#3350

Co-authored-by: Cristian Garcia <[email protected]>
PiperOrigin-RevId: 540391111
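
For reference, a minimal sketch of the new options (the `MLP` module and input shape are illustrative, and this assumes a Flax release that includes this change):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):  # illustrative module, not part of this commit
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(64)(x))
    return nn.Dense(1)(x)

x = jnp.ones((16, 32))
# The summary table gains a `flops` column (forward pass) and a
# `vjp_flops` column (forward + backward pass) when these flags are set.
print(MLP().tabulate(
    jax.random.PRNGKey(0), x,
    compute_flops=True,
    compute_vjp_flops=True))
```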
2 people authored and copybara-github committed Sep 20, 2023
1 parent b0fa629 commit 344cae7
Showing 1 changed file with 13 additions and 13 deletions.
--- a/init2winit/model_lib/dlrm.py
+++ b/init2winit/model_lib/dlrm.py
@@ -132,7 +132,7 @@ def __call__(self, x, train):
       bot_mlp_input = nn.Dense(
           dense_dim,
           kernel_init=jnn.initializers.glorot_uniform(),
-          bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+          bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
       )(bot_mlp_input)
       bot_mlp_input = activation_fn(bot_mlp_input)
       bot_mlp_input = normalizer_layer()(bot_mlp_input)
@@ -143,7 +143,7 @@ def __call__(self, x, train):

     base_init_fn = jnn.initializers.uniform(scale=1.0)
     if self.embedding_init_multiplier is None:
-      embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size)
+      embedding_init_multiplier = 1 / self.vocab_size**0.5
     else:
       embedding_init_multiplier = self.embedding_init_multiplier
     # Embedding table init and lookup for a single unified table.
@@ -173,9 +173,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
       top_mlp_input = nn.Dense(
           fan_out,
           kernel_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
+              stddev=(2.0 / (fan_in + fan_out))**0.5),
           bias_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
+              stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))(
                   top_mlp_input)
       if layer_idx < (num_layers_top - 1):
         top_mlp_input = activation_fn(top_mlp_input)
@@ -251,21 +251,21 @@ def __call__(self, x, train):
         mlp_bottom_dims[0],
         kernel_init=jnn.initializers.glorot_uniform(),
         bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])),
+            stddev=1.0 / mlp_bottom_dims[0]**0.5),
     )(bot_mlp_input)
     bot_mlp_input = activation_fn(bot_mlp_input)

     for dense_dim in mlp_bottom_dims[1:]:
       x = nn.Dense(
           dense_dim,
           kernel_init=jnn.initializers.glorot_uniform(),
-          bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+          bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
       )(bot_mlp_input)
       bot_mlp_input += activation_fn(x)

     base_init_fn = jnn.initializers.uniform(scale=1.0)
     if self.embedding_init_multiplier is None:
-      embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size)
+      embedding_init_multiplier = 1 / self.vocab_size**0.5
     else:
       embedding_init_multiplier = self.embedding_init_multiplier
     # Embedding table init and lookup for a single unified table.
@@ -289,19 +289,19 @@ def scaled_init(key, shape, dtype=jnp.float_):
     top_mlp_input = nn.Dense(
         mlp_top_dims[0],
         kernel_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
+            stddev=(2.0 / (mlp_input_dim + mlp_top_dims[0]))**0.5),
         bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
+            stddev=(1.0 / mlp_top_dims[0])**0.5))(
                 top_mlp_input)
     top_mlp_input = activation_fn(top_mlp_input)
     for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
       fan_in = mlp_top_dims[layer_idx - 1]
       x = nn.Dense(
           fan_out,
           kernel_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
+              stddev=(2.0 / (fan_in + fan_out))**0.5),
           bias_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
+              stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))(
                   top_mlp_input)
       x = activation_fn(x)
       if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2:
@@ -313,9 +313,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
     logits = nn.Dense(
         1,
         kernel_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
+            stddev=(2.0 / (mlp_top_dims[-2] + 1))**0.5),
         bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0)))(top_mlp_input)
+            stddev=1.0))(top_mlp_input)
     return logits


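Every hunk in the diff above makes the same substitution: `jnp.sqrt(...)` in an initializer's `stddev` argument is rewritten with an equivalent `**0.5` expression. A minimal sketch checking that the two forms agree numerically (the `dense_dim` value is illustrative; `jnn` stands for `jax.nn` as in the file):

```python
import jax
import jax.numpy as jnp
from jax import nn as jnn

dense_dim = 512  # illustrative value

# `**` binds tighter than `/`, so 1.0 / dense_dim**0.5 equals sqrt(1.0 / dense_dim).
old_stddev = jnp.sqrt(1.0 / dense_dim)  # jnp.sqrt returns a JAX scalar array
new_stddev = 1.0 / dense_dim**0.5       # plain Python float
assert jnp.allclose(old_stddev, new_stddev)

# Either value parameterizes the same initializer distribution.
init = jnn.initializers.normal(stddev=new_stddev)
bias = init(jax.random.PRNGKey(0), (dense_dim,), jnp.float32)
```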

