From e74aa2e27173b2de2a6b0a6c111783dbfe20ecd1 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 2 Jul 2024 09:44:27 -0700
Subject: [PATCH] Make dynamic shape based export selectable (#4100)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4100

Most backends, such as mps, qnn, and coreml, do not support dynamic shapes.
For mps in particular, the .item() query (a data-dependent dynamic shape)
results in a lowering error. The error is likely fixable in the delegate,
but it is not clear whether the delegate is meant to support dynamic shapes
at all.

//pre-existing lint error
bypass-github-export-checks

Thus, make the dynamic-shape-specific changes selectable for now.

ghstack-source-id: 232298483

Reviewed By: cccclai

Differential Revision: D59236591

fbshipit-source-id: 2870148373f2f9f0ee9d4646355ac0664d2c2f68
---
 examples/models/llama2/llama_transformer.py | 93 ++++++++++++-------
 examples/models/llama2/model.py             |  1 +
 .../models/llama2/tests/test_simple_sdpa.py |  2 +
 3 files changed, 61 insertions(+), 35 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 8bebfb8e9b..5f1d8fb666 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -88,6 +88,7 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    enable_dynamic_shape: bool = False  # export model with dynamic shape support
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
     )
@@ -188,6 +189,7 @@ def __init__(
         n_heads: int,
         head_dim: int,
         transpose_cache: bool,
+        enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
@@ -198,6 +200,7 @@ def __init__(
             cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
 
         self.transpose_cache = transpose_cache
+        self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
         )
@@ -209,23 +212,31 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
-        start_pos = input_pos[-1].item()
-        torch._check_is_size(start_pos)
-        torch._check(start_pos < self.max_seq_length)
-        seq_length = k_val.size(2)
-        # Replace the entry in the cache for this token
-        # The following lines are equivalent to:
-        # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-        # We use .narrow() here to make the compiler happy
-        # pyre-ignore: Incompatible parameter type [6]
-        narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
-        # pyre-ignore: Incompatible parameter type [6]
-        narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
-
-        narrowed_k.copy_(k_val)
-        narrowed_v.copy_(v_val)
-        return self.k_cache, self.v_cache
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_seq_length)
+            seq_length = k_val.size(2)
+            # Replace the entry in the cache for this token
+            # The following lines are equivalent to:
+            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            # We use .narrow() here to make the compiler happy
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+
+            narrowed_k.copy_(k_val)
+            narrowed_v.copy_(v_val)
+            return self.k_cache, self.v_cache
+        else:
+            k_out = self.k_cache
+            v_out = self.v_cache
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
+
+            return k_out, v_out
 
 
 class SDPA(nn.Module):
@@ -236,6 +247,7 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -243,6 +255,7 @@ def __init__(
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_seq_len = max_seq_len
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -259,12 +272,15 @@ def forward(
             v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        start_pos = input_pos[-1].item()
-        torch._check_is_size(start_pos)
-        torch._check(start_pos < self.max_seq_len)
-        seq_length = q.size(2)
-        # pyre-ignore: Incompatible parameter type [6]
-        attn_mask = mask.narrow(0, start_pos, seq_length)
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_seq_len)
+            seq_length = q.size(2)
+            # pyre-ignore: Incompatible parameter type [6]
+            attn_mask = mask.narrow(0, start_pos, seq_length)
+        else:
+            attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -312,6 +328,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.n_kv_heads,
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
                 kv_cache=self.kv_cache,
@@ -319,6 +336,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
                 max_seq_len=self.max_seq_len,
+                enable_dynamic_shape=args.enable_dynamic_shape,
             )
 
     def forward(
@@ -496,18 +514,23 @@ def forward(
                 input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"
 
-            # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-            input_pos_item = input_pos[-1].item()
-            torch._check_is_size(input_pos_item)
-            # Setting this value to 32 for no particular reason.
-            # It is mainly to make export happy as the resulting
-            # asserts are ignored anyway.
-            # We really need unbounded start_pos
-            torch._check(input_pos_item < self.params.max_seq_len)
-            # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-            freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-            # pyre-ignore: Incompatible parameter type [6]
-            freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+            else:
+                # Without dynamic shape, the .item() call would introduce
+                # symints by querying data from the tensor.
+                # This path avoids that for the mps backend, although mps
+                # may well be able to support dynamic shape.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
         else:
             assert input_pos is None, "input_pos is unused when use_kv_cache is False"
             freqs_cos = self.freqs_cos[:seqlen]
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 197de2289b..e3eb1717be 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -162,6 +162,7 @@ def __init__(self, **kwargs):
             max_batch_size=max_batch_size,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
+            enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
         if kwargs.get("fairseq2", False):
diff --git a/examples/models/llama2/tests/test_simple_sdpa.py b/examples/models/llama2/tests/test_simple_sdpa.py
index 61f14e58dc..264ed3dde3 100644
--- a/examples/models/llama2/tests/test_simple_sdpa.py
+++ b/examples/models/llama2/tests/test_simple_sdpa.py
@@ -30,6 +30,7 @@ def test_simple_sdpa(self):
             n_heads=n_heads,
             head_dim=head_dim,
             transpose_cache=True,
+            enable_dynamic_shape=False,
         )
         sdpa = SDPA(
             kv_cache=copy.deepcopy(kv_cache),
@@ -37,6 +38,7 @@ def test_simple_sdpa(self):
             head_dim=head_dim,
             n_rep=n_rep,
             max_seq_len=max_seq_length,
+            enable_dynamic_shape=False,
        )
         input_pos = torch.tensor([0])
         query = torch.randn(1, 1, n_local_heads, head_dim)
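
For reference, a minimal standalone sketch of the two KV-cache update strategies that enable_dynamic_shape selects between, shown on plain tensors. The shapes and variable names below are illustrative only, not taken from the patch; the real logic lives in KVCache.update above.

import torch

# Illustrative shapes only; the real cache shape comes from ModelArgs.
n_heads, max_seq_length, head_dim = 2, 8, 4
k_cache = torch.zeros(1, n_heads, max_seq_length, head_dim)
input_pos = torch.tensor([3])                 # current write position
k_val = torch.randn(1, n_heads, 1, head_dim)  # one new token

# enable_dynamic_shape=True path: turn the position into an int and use
# narrow()/copy_() so export can reason about the data-dependent size.
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < max_seq_length)
k_cache.narrow(2, start_pos, k_val.size(2)).copy_(k_val)

# enable_dynamic_shape=False path: plain tensor indexing with no .item()
# call, which is the path backends such as mps can lower.
k_cache[:, :, input_pos] = k_val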