Merge pull request #1067 from AI-Hypercomputer:less_prefill_array
PiperOrigin-RevId: 700785020
maxtext authors committed Nov 27, 2024
2 parents 3f93a89 + 1d7aa1b commit 7331e13
Showing 3 changed files with 159 additions and 16 deletions.
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -441,6 +441,9 @@ inference_microbenchmark_log_file_path: ""
inference_metadata_file: "" # path to a json file
enable_model_warmup: False

# Stack the prefill result cache across the layers to reduce the
# per-layer Python overhead.
stack_prefill_result_cache: False

# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
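
For context, a minimal sketch of what the new stack_prefill_result_cache flag enables (illustrative shapes only; the name per_layer_cache is an assumption, not taken from this change). When the flag is on, the per-layer prefill caches are merged into a single pytree whose leaves gain a leading layer axis, so the prefill result carries a few stacked arrays instead of one set of arrays per layer:

# Sketch only, not part of this commit: stack per-layer cache pytrees
# along a new leading layer axis with jax.tree.map + jnp.stack.
import jax
import jax.numpy as jnp

num_layers = 4
per_layer_cache = [  # hypothetical per-layer prefill caches
    {"cached_key": jnp.ones((1, 8)), "cached_value": jnp.ones((1, 8))}
    for _ in range(num_layers)
]
stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_layer_cache)
assert stacked["cached_key"].shape == (num_layers, 1, 8)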
110 changes: 94 additions & 16 deletions MaxText/maxengine.py
@@ -131,12 +131,22 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)

self.prefill_kv_cache_annotations = max_utils.get_prefill_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, x), self.prefill_kv_cache_annotations
lambda x: jax.sharding.NamedSharding(self._mesh, x),
self.prefill_kv_cache_annotations,
)

if self.config.stack_prefill_result_cache:
# Add an extra leading axis to the sharding spec for the layer axis created by stacking.
self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, jax.sharding.PartitionSpec(None, *x.spec)),
self.prefill_kv_cache_shardings,
)
self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings["decoder"]["layers_0"]

self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
self.kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations
lambda x: jax.sharding.NamedSharding(self._mesh, x),
self.kv_cache_annotations,
)

if self.model.quant and not self.config.checkpoint_is_quantized:
@@ -172,12 +182,40 @@ def model_apply(_p, _rng):
params["aqt"] = new_vars["aqt"]
params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"])
self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
params,
)
max_utils.save_quantized_checkpoint_if_configured(self.config, params)
self.model.quant.quant_mode = quantizations.get_quant_mode("serve")
return params

def _maybe_stack_prefill_result_cache(self, cache):
"""Stack the caches across the layers."""
if not self.config.stack_prefill_result_cache:
return cache

layer_keys = []
for i in range(self.config.num_decoder_layers):
layer_keys.append(f"layers_{i}")

layer_cache = [cache["decoder"][layer_key] for layer_key in layer_keys]

return jax.tree.map(lambda *c: jnp.stack(c), *layer_cache)

def _maybe_unstack_prefill_result_cache(self, cache):
"""Unstack the caches across the layers."""
if not self.config.stack_prefill_result_cache:
return cache

flat_cache, treedef = jax.tree.flatten(cache)
layer_cache = [jax.tree.unflatten(treedef, flat_cache_vars) for flat_cache_vars in zip(*flat_cache, strict=True)]
res_cache = {"decoder": {}}

for i in range(self.config.num_decoder_layers):
res_cache["decoder"][f"layers_{i}"] = layer_cache[i]

return res_cache

@functools.partial(jax.jit, static_argnums=(0,))
def prefill(
self,
@@ -231,7 +269,9 @@ def prefill(
next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32)
generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
selected_logits = jax.lax.dynamic_slice(
flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2])
flat_logits,
(0, true_length - 1, 0),
(flat_logits.shape[0], 1, flat_logits.shape[2]),
)
selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding)

@@ -259,9 +299,12 @@ def prefill(
samples_per_slot=1,
)

cache = new_vars["cache"]
cache = self._maybe_stack_prefill_result_cache(cache)

return {
"logits": selected_logits,
"cache": new_vars["cache"],
"cache": cache,
"next_pos": next_pos,
"generated_tokens": generated_tokens,
"tokens": first_generated_token,
@@ -346,9 +389,17 @@ def insert(
"""Insert into KV cache"""
unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)

unboxed_prefix["cache"] = self._maybe_unstack_prefill_result_cache(unboxed_prefix["cache"])

def copy(path, partial_cache, full_cache, annotations):
path_key = path[-1].key
if path_key in ["cache_ar_index", "cached_ar_key", "cached_ar_value", "cached_ar_key_scale", "cached_ar_value_scale"]:
if path_key in [
"cache_ar_index",
"cached_ar_key",
"cached_ar_value",
"cached_ar_key_scale",
"cached_ar_value_scale",
]:
return full_cache # we don't even zero these out because we can mask them out.

batch_idx = -1
@@ -388,12 +439,18 @@ def copy(path, partial_cache, full_cache, annotations):
raise ValueError(f"We don't have a strategy for inserting {path_key}")

inserted_cache = jax.tree_util.tree_map_with_path(
copy, unboxed_prefix["cache"], decode_state["cache"], self.kv_cache_annotations_named
copy,
unboxed_prefix["cache"],
decode_state["cache"],
self.kv_cache_annotations_named,
)
inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
decode_state["generated_tokens"], unboxed_prefix["generated_tokens"], slot, 0
decode_state["generated_tokens"],
unboxed_prefix["generated_tokens"],
slot,
0,
)
inserted_tokens = jax.lax.dynamic_update_index_in_dim(decode_state["tokens"], unboxed_prefix["tokens"], slot, 0)

@@ -458,11 +515,26 @@ def init(abstract_params):
mutable=["cache"],
)

next_pos = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
generated_tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
next_pos = jnp.zeros(
(int(self.config.per_device_batch_size * jax.device_count()), 1),
dtype=jnp.int32,
)
generated_tokens = jnp.zeros(
(int(self.config.per_device_batch_size * jax.device_count()), 1),
dtype=jnp.int32,
)
tokens = jnp.zeros(
(int(self.config.per_device_batch_size * jax.device_count()), 1),
dtype=jnp.int32,
)
return {
"logits": jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)),
"logits": jnp.zeros(
(
int(self.config.per_device_batch_size * jax.device_count()),
1,
self.config.vocab_size,
)
),
"cache": cache["cache"],
"next_pos": next_pos,
"generated_tokens": generated_tokens,
@@ -477,7 +549,8 @@ def init(abstract_params):
mesh_annotations = nn.logical_to_mesh(logical_annotations)

shardings = jax.tree_util.tree_map(
lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations
lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation),
mesh_annotations,
)

@functools.partial(jax.jit, out_shardings=shardings)
@@ -519,16 +592,21 @@ def colocated_cpus(self) -> None:
raise NotImplementedError


def set_engine_vars_from_base_engine(engine: engine_api.Engine, base_engine: engine_api.Engine, rng: jax.random.PRNGKey):
def set_engine_vars_from_base_engine(
engine: engine_api.Engine,
base_engine: engine_api.Engine,
rng: jax.random.PRNGKey,
):
"""Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
mesh, and kv cache related vars set.
"""
engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
engine.state_mesh_annotations = base_engine.state_mesh_annotations
engine.abstract_params = base_engine.abstract_params
engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine._mesh) # pylint: disable=protected-access
engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access
engine.kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(engine._mesh, x), engine.kv_cache_annotations # pylint: disable=protected-access
lambda x: jax.sharding.NamedSharding(engine.mesh, x),
engine.kv_cache_annotations, # pylint: disable=protected-access
)
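
A short aside on the unstacking pattern used in _maybe_unstack_prefill_result_cache above (a minimal sketch with made-up shapes, not code from this commit): jax.tree.flatten returns one stacked array per cache leaf, zip(*flat, strict=True) regroups the i-th slice of every leaf, and jax.tree.unflatten rebuilds the per-layer pytree for layer i.

# Sketch only, not part of this commit: rebuild per-layer pytrees from
# stacked leaves via flatten / zip / unflatten.
import jax
import jax.numpy as jnp

stacked = {"a": jnp.arange(6).reshape(3, 2), "b": jnp.arange(3).reshape(3, 1)}
flat, treedef = jax.tree.flatten(stacked)  # leaves: a -> (3, 2), b -> (3, 1)
per_layer = [jax.tree.unflatten(treedef, leaves) for leaves in zip(*flat, strict=True)]
assert len(per_layer) == 3
assert per_layer[0]["a"].shape == (2,) and per_layer[0]["b"].shape == (1,)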


62 changes: 62 additions & 0 deletions MaxText/tests/maxengine_test.py
@@ -0,0 +1,62 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Tests for the maxengine """

import jax
from jax import numpy as jnp
import numpy as np
import unittest
import pyconfig
from maxengine import MaxEngine


class MaxEngineTest(unittest.TestCase):
"""Tests for MaxEngine."""

# TODO: add unit test for the MaxEngine.

def test_stack_and_unstack_prefill_cache(self):
pyconfig.initialize(
[None, "configs/base.yml"],
enable_checkpointing=False,
stack_prefill_result_cache=True,
)
config = pyconfig.config
engine = MaxEngine(config, jax.devices())
num_layers = engine.config.num_decoder_layers
input = {
"decoder": {},
}
for i in range(num_layers):
input["decoder"][f"layers_{i}"] = {
"a": jnp.ones((1, 10)),
"b": jnp.ones((1, 9)),
}

expected_stacked = {
"a": jnp.ones((num_layers, 1, 10)),
"b": jnp.ones((num_layers, 1, 9)),
}
got_stacked = engine._maybe_stack_prefill_result_cache(input)
jax.tree.map(np.testing.assert_array_equal, got_stacked, expected_stacked)

got_unstacked = engine._maybe_unstack_prefill_result_cache(got_stacked)
jax.tree.map(np.testing.assert_array_equal, got_unstacked, input)


if __name__ == "__main__":
unittest.main()
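
A usage note, hedged: the exact invocation is an assumption rather than part of this commit, but since the test loads "configs/base.yml" by relative path and ends with unittest.main(), it presumably runs from the MaxText/ directory, e.g. python3 tests/maxengine_test.py.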
