[DO NOT MERGE] Ranran hide a2a #1029

Draft · wants to merge 18 commits into base: main
16 changes: 10 additions & 6 deletions MaxText/configs/base.yml
@@ -110,11 +110,11 @@ decoder_block: "llama2" # which style of DecoderBlock to use.
# base_mlp_dim, base_num_decoder_layers and/or head_dim.
weight_dtype: float32
global_parameter_scale: 1
-base_emb_dim: 2048
+base_emb_dim: 6144
base_num_query_heads: 16
base_num_kv_heads: 16
-base_mlp_dim: 7168
-base_num_decoder_layers: 16
+base_mlp_dim: 36864
+base_num_decoder_layers: 2
head_dim: 128
mlp_activations: ["silu", "linear"]
dropout_rate: 0.0
@@ -124,11 +124,15 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.

# mixture of experts (moe)
-num_experts: 1
+num_experts: 64
num_experts_per_tok: 1
megablox: True
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
+num_moe_a2a_chunks: 1 # Number of chunks used in the MoE FF layers to pipeline the matmuls and hide the A2A.
+# We can potentially hide a (chunks - 1) / chunks fraction of the A2A (e.g. 3/4 with 4 chunks), at the cost of
+# each matmul being a factor of chunks smaller - which may make the matmuls less efficient.
+# You should use --xla_tpu_enable_async_all_to_all in conjunction with num_moe_a2a_chunks > 1.

# pipeline parallelism
# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
@@ -208,7 +212,7 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
+mesh_axes: ['data', 'expert', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -248,7 +252,7 @@ logical_axis_rules: [
['exp', 'expert'],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]
+data_sharding: [['data', 'expert', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
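As a rough, minimal sketch of the chunking idea behind num_moe_a2a_chunks above (the PR's actual prototypes are example_hide.py and hide_ff2_a2a.py below): slicing the embed dimension into chunks lets each chunk's all-to-all overlap with another chunk's matmul when async all-to-all is enabled. The sizes, the 4-device 'expert' mesh, and the function names here are assumptions for illustration only.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.sharding import Mesh, PartitionSpec as P

BATCH, EXP, EMBED, MLP, CHUNKS = 512, 4, 1024, 2048, 4
assert jax.device_count() == EXP  # sketch assumes 4 devices on a single 'expert' axis
mesh = Mesh(mesh_utils.create_device_mesh((EXP,)), ("expert",))

def chunk_a2a(x):
  # Tiled all-to-all: move the sharding from the batch dim (0) to the expert dim (1).
  return jax.lax.all_to_all(x, "expert", 1, 0, tiled=True)

@jax.jit
def chunked_moe_ff(acts, weights):  # acts: [B, EXP, EMBED], weights: [EXP, EMBED, MLP]
  out = jnp.zeros((BATCH, EXP, MLP), dtype=acts.dtype)
  step = EMBED // CHUNKS
  for i in range(CHUNKS):
    a = jax.lax.dynamic_slice_in_dim(acts, i * step, step, 2)
    a = shard_map.shard_map(chunk_a2a, mesh,
                            in_specs=P("expert", None, None),
                            out_specs=P(None, "expert", None))(a)
    w = jax.lax.dynamic_slice_in_dim(weights, i * step, step, 1)
    # With async all-to-all, the matmul of one chunk can hide the A2A of the next chunk.
    out = out + jnp.einsum("BXE,XEM->BXM", a, w)
  return out

print(chunked_moe_ff(jnp.ones((BATCH, EXP, EMBED)), jnp.ones((EXP, EMBED, MLP))).shape)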
31 changes: 31 additions & 0 deletions MaxText/configs/models/custom-moe-multi.yml
@@ -0,0 +1,31 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for custom_moe

base_emb_dim: 8192
base_num_query_heads: 112
base_num_kv_heads: 8
base_mlp_dim: 32768
base_num_decoder_layers: 4
head_dim: 256
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 64
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mistral"
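For illustration of how a model config such as custom-moe-multi.yml layers on top of base.yml, a generic sketch (this is not MaxText's own config loader; the helper name is made up): keys in later files override keys from earlier ones.

import yaml  # assumes PyYAML is available

def merge_configs(*paths):
  """Sketch of layered config loading: keys in later files win."""
  merged = {}
  for path in paths:
    with open(path) as f:
      merged.update(yaml.safe_load(f) or {})
  return merged

cfg = merge_configs("MaxText/configs/base.yml", "MaxText/configs/models/custom-moe-multi.yml")
print(cfg["num_experts"], cfg["base_num_decoder_layers"])  # 64 and 4 (the model file overrides base)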
24 changes: 24 additions & 0 deletions MaxText/configs/models/custom-moe-single.yml
@@ -0,0 +1,24 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for custom_moe


mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 1_000_000
decoder_block: "mistral"
93 changes: 93 additions & 0 deletions MaxText/example_hide.py
@@ -0,0 +1,93 @@
import datetime
import os
import random
import string

import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental import shard_map
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


#!!!! Inside google3, set trace_dir to a CNS path or another profiling destination.
def simple_timeit(f, *args, tries=10, task=None):
"""Simple utility to time a function for multiple runs"""
assert task is not None

trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
trace_dir = f"gs://mattdavidow-br/{trace_name}"

outcomes_ms = []
jax.block_until_ready(f(*args)) # warm it up!
jax.profiler.start_trace(trace_dir)

for _ in range(tries):
s = datetime.datetime.now()
jax.block_until_ready(f(*args))
e = datetime.datetime.now()
outcomes_ms.append(1000 * (e - s).total_seconds())
jax.profiler.stop_trace()

average_time_ms = sum(outcomes_ms) / len(outcomes_ms)
print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}")
return average_time_ms


# Baseline non-overlapped implementation to compare against.
# Ideally the compiler would find an overlapped schedule even for this naive code.
def blocking_a2a(input_activations, weights):
input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X
return jnp.einsum("BXE,XEM -> BXM", input_activations, weights)

# Necessary explicit communication (use shard map)
def a2a(input_chunk):
return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=True)

# Desired overlapped implementation
def overlap_a2a(input_activations, weights):
num_chunks = 4
chunk_size = EMBED // num_chunks

partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP))
partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model')))
for i in range(num_chunks):
chunk_start = chunk_size * i

input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2)
#input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X
input_chunk = shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk)

weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1)

partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk)
return partial_sum

def create_inputs():
input_activations = jnp.ones((BATCH_PER_EXP, EXP, EMBED),dtype=jnp.bfloat16)
input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model')))

weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16)
weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model')))
return input_activations, weights

BATCH_PER_EXP = 2048
EMBED = 4096
MLP = 8192
EXP = 4

global mesh
data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4
ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism]
devices_array = mesh_utils.create_device_mesh(ici_parallelism)
mesh = Mesh(devices_array, ["data", "model", "expert"])

input_activations, weights = jax.jit(create_inputs)()

jit_overlap_a2a = jax.jit(overlap_a2a)
simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a")

# jit_blocking_a2a = jax.jit(blocking_a2a)
# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a")
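A correctness check one could run at this point before profiling, mirroring the commented-out test in hide_ff2_a2a.py below; the tolerances are an assumption.

# Hypothetical check: the chunked/overlapped result should match the blocking baseline.
overlapped = jax.jit(overlap_a2a)(input_activations, weights)
blocking = jax.jit(blocking_a2a)(input_activations, weights)
assert jnp.allclose(overlapped.astype(jnp.float32), blocking.astype(jnp.float32), rtol=1e-2, atol=1e-2)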
104 changes: 104 additions & 0 deletions MaxText/hide_ff2_a2a.py
@@ -0,0 +1,104 @@
import datetime
import os
import random
import string

import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental import shard_map
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


#!!!! Inside google3, set trace_dir to a CNS path or another profiling destination.
def simple_timeit(f, *args, tries=10, task=None):
"""Simple utility to time a function for multiple runs"""
assert task is not None

trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
trace_dir = f"gs://mattdavidow-br/{trace_name}"

outcomes_ms = []
jax.block_until_ready(f(*args)) # warm it up!
jax.profiler.start_trace(trace_dir)

for _ in range(tries):
s = datetime.datetime.now()
jax.block_until_ready(f(*args))
e = datetime.datetime.now()
outcomes_ms.append(1000 * (e - s).total_seconds())
jax.profiler.stop_trace()

average_time_ms = sum(outcomes_ms) / len(outcomes_ms)
print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}")
return average_time_ms


# Baseline non-overlapped implementation to compare against.
# Ideally the compiler would find an overlapped schedule even for this naive code.
def blocking_a2a(input_activations, weights):
outputs = jnp.einsum("BXM,XEM -> BXE", input_activations, weights)
outputs = jax.lax.with_sharding_constraint(outputs, NamedSharding(mesh, P('expert', None, 'model'))) #A2A B,EXP/X -> B/X,EXP
return outputs

# Necessary explicit communication (use shard map)
def a2a(input_chunk):
return jax.lax.all_to_all(input_chunk, 'expert', 0, 1, tiled=True)

# Desired overlapped implementation
def overlap_a2a(input_activations, weights):
  num_chunks = 4
  chunk_size = EMBED // num_chunks

  ff_output_post_a2a = jnp.zeros((BATCH_PER_EXP, EXP, EMBED), dtype=input_activations.dtype)
  # After the a2a, the batch dim is sharded by 'expert' and the expert dim is unsharded
  ff_output_post_a2a = jax.lax.with_sharding_constraint(ff_output_post_a2a, NamedSharding(mesh, P('expert', None, 'model')))
  for i in range(num_chunks):
    chunk_start = chunk_size * i

    weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1)
    result_chunk_before_a2a = jnp.einsum("BXM,XEM -> BXE", input_activations, weight_chunk)

    result_chunk = shard_map.shard_map(a2a, mesh, in_specs=P(None, 'expert', 'model'), out_specs=P('expert', None, 'model'))(result_chunk_before_a2a)
    ff_output_post_a2a = jax.lax.dynamic_update_slice(ff_output_post_a2a, result_chunk, (0, 0, chunk_start))
  return ff_output_post_a2a


def create_inputs():
input_activations = jax.random.normal(jax.random.PRNGKey(0), (BATCH_PER_EXP, EXP, MLP), dtype=jnp.bfloat16)
# Inputs start out expert sharded
input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P(None, 'expert','model')))

weights = jax.random.normal(jax.random.PRNGKey(1), (EXP, EMBED, MLP), dtype=jnp.bfloat16)
weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model')))
return input_activations, weights

BATCH_PER_EXP = 16384
EMBED = 4096
MLP = 8192
EXP = 4

global mesh
expert_parallelism, data_parallelism, model_parallelism = 4, 1, 1
ici_parallelism = [expert_parallelism, data_parallelism, model_parallelism]
devices_array = mesh_utils.create_device_mesh(ici_parallelism)
mesh = Mesh(devices_array, ["expert", "data", "model"])

input_activations, weights = jax.jit(create_inputs)()

# correctness test
# overlapped_results = jax.jit(overlap_a2a)(input_activations, weights)
# blocking_results = jax.jit(blocking_a2a)(input_activations, weights)
# # assert overlapped_results and blocking_results are close
# assert jnp.allclose(overlapped_results, blocking_results, rtol=1e-3, atol=1e-2)

# Profile overlap solution
jit_overlap_a2a = jax.jit(overlap_a2a)
simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a")

# Profile blocking solution
# jit_blocking_a2a = jax.jit(blocking_a2a)
# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a")
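Per the base.yml comment above, num_moe_a2a_chunks > 1 should be paired with --xla_tpu_enable_async_all_to_all. One way to pass that XLA flag to these scripts on TPU is via the LIBTPU_INIT_ARGS environment variable set before JAX initializes the backend; the exact "=true" syntax and this launch pattern are assumptions about the runtime, not part of the PR.

# Hypothetical launch wrapper: set the flag before importing jax / initializing the TPU backend.
import os
os.environ["LIBTPU_INIT_ARGS"] = (os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_enable_async_all_to_all=true").strip()

import jax  # imported only after LIBTPU_INIT_ARGS is set
print(jax.devices())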