diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 8bebfb8e9b..5f1d8fb666 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -88,6 +88,7 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    enable_dynamic_shape: bool = False  # export model with dynamic shape support
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
     )
@@ -188,6 +189,7 @@ def __init__(
         n_heads: int,
         head_dim: int,
         transpose_cache: bool,
+        enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
@@ -198,6 +200,7 @@ def __init__(
             cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
 
         self.transpose_cache = transpose_cache
+        self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
         )
@@ -209,23 +212,31 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
-        start_pos = input_pos[-1].item()
-        torch._check_is_size(start_pos)
-        torch._check(start_pos < self.max_seq_length)
-        seq_length = k_val.size(2)
-        # Replace the entry in the cache for this token
-        # The following lines are equivalent to:
-        # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-        # We use .narrow() here to make the compiler happy
-        # pyre-ignore: Incompatible parameter type [6]
-        narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
-        # pyre-ignore: Incompatible parameter type [6]
-        narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
-
-        narrowed_k.copy_(k_val)
-        narrowed_v.copy_(v_val)
-        return self.k_cache, self.v_cache
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_seq_length)
+            seq_length = k_val.size(2)
+            # Replace the entry in the cache for this token
+            # The following lines are equivalent to:
+            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            # We use .narrow() here to make the compiler happy
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+
+            narrowed_k.copy_(k_val)
+            narrowed_v.copy_(v_val)
+            return self.k_cache, self.v_cache
+        else:
+            k_out = self.k_cache
+            v_out = self.v_cache
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
+
+            return k_out, v_out
 
 
 class SDPA(nn.Module):
@@ -236,6 +247,7 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -243,6 +255,7 @@ def __init__(
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_seq_len = max_seq_len
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -259,12 +272,15 @@ def forward(
             v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        start_pos = input_pos[-1].item()
-        torch._check_is_size(start_pos)
-        torch._check(start_pos < self.max_seq_len)
-        seq_length = q.size(2)
-        # pyre-ignore: Incompatible parameter type [6]
-        attn_mask = mask.narrow(0, start_pos, seq_length)
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_seq_len)
+            seq_length = q.size(2)
+            # pyre-ignore: Incompatible parameter type [6]
+            attn_mask = mask.narrow(0, start_pos, seq_length)
+        else:
+            attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -312,6 +328,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.n_kv_heads,
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
                 kv_cache=self.kv_cache,
@@ -319,6 +336,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
                 max_seq_len=self.max_seq_len,
+                enable_dynamic_shape=args.enable_dynamic_shape,
             )
 
     def forward(
@@ -496,18 +514,23 @@ def forward(
                 input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"
 
-            # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-            input_pos_item = input_pos[-1].item()
-            torch._check_is_size(input_pos_item)
-            # Setting this value to 32 for no particular reason.
-            # It is mainly to make export happy as the resulting
-            # asserts are ignored anyway.
-            # We really need unbounded start_pos
-            torch._check(input_pos_item < self.params.max_seq_len)
-            # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-            freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-            # pyre-ignore: Incompatible parameter type [6]
-            freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+            else:
+                # When not using dynamic shape, calling .item() on input_pos
+                # produces symints by querying the tensor's data. This path
+                # avoids that for the MPS backend, though the MPS backend can
+                # probably support dynamic shape as well.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
         else:
             assert input_pos is None, "input_pos is unused when use_kv_cache is False"
             freqs_cos = self.freqs_cos[:seqlen]
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 197de2289b..e3eb1717be 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -162,6 +162,7 @@ def __init__(self, **kwargs):
             max_batch_size=max_batch_size,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
+            enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
         if kwargs.get("fairseq2", False):
diff --git a/examples/models/llama2/tests/test_simple_sdpa.py b/examples/models/llama2/tests/test_simple_sdpa.py
index 61f14e58dc..264ed3dde3 100644
--- a/examples/models/llama2/tests/test_simple_sdpa.py
+++ b/examples/models/llama2/tests/test_simple_sdpa.py
@@ -30,6 +30,7 @@ def test_simple_sdpa(self):
             n_heads=n_heads,
             head_dim=head_dim,
             transpose_cache=True,
+            enable_dynamic_shape=False,
         )
         sdpa = SDPA(
             kv_cache=copy.deepcopy(kv_cache),
@@ -37,6 +38,7 @@ def test_simple_sdpa(self):
             head_dim=head_dim,
             n_rep=n_rep,
             max_seq_len=max_seq_length,
+            enable_dynamic_shape=False,
         )
         input_pos = torch.tensor([0])
         query = torch.randn(1, 1, n_local_heads, head_dim)
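A few eager-mode sketches of the behavior the flag toggles, for context. First, the two `KVCache.update` strategies: for a transposed `(B, H, S, D)` cache, the `.narrow()` + `copy_` path and the index-assignment path write the same slice of the cache. The shapes below are invented for illustration; only the two update idioms come from the diff.

```python
import torch

n_heads, head_dim, max_seq_len = 2, 4, 8
cache_a = torch.zeros(1, n_heads, max_seq_len, head_dim)
cache_b = torch.zeros(1, n_heads, max_seq_len, head_dim)

input_pos = torch.tensor([3])                  # position being decoded
k_val = torch.randn(1, n_heads, 1, head_dim)   # one new token's keys

# enable_dynamic_shape=True path: data-dependent start_pos, narrow + copy_
start_pos = input_pos[-1].item()
cache_a.narrow(2, start_pos, k_val.size(2)).copy_(k_val)

# enable_dynamic_shape=False path: advanced indexing on the seq dimension
cache_b[:, :, input_pos] = k_val

assert torch.equal(cache_a, cache_b)
```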
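Second, the two mask-slicing branches in `SDPA.forward` select the same rows of the causal mask; the static branch adds explicit broadcast dimensions for batch and heads instead of relying on broadcasting a 2-D slice. The causal-mask construction here is an assumption for the sketch, not taken from the diff.

```python
import torch

max_seq_len = 8
mask = torch.full((max_seq_len, max_seq_len), float("-inf")).triu(1)

input_pos = torch.tensor([3])
q_len = 1  # decoding one token at a time

# enable_dynamic_shape=True branch: narrow rows starting at start_pos
start_pos = input_pos[-1].item()
mask_dynamic = mask.narrow(0, start_pos, q_len)   # [q_len, max_seq_len]

# enable_dynamic_shape=False branch: index rows, add (B, H) broadcast dims
mask_static = mask[None, None, input_pos]         # [1, 1, q_len, max_seq_len]

assert torch.equal(mask_dynamic, mask_static[0, 0])
```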
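Finally, the rotary-frequency selection in `Transformer.forward` follows the same pattern. The `torch._check_is_size` / `torch._check` calls matter only under `torch.export`, where `.item()` yields an unbacked symbolic int that must be bounded before `narrow()` can be traced; in eager mode the two branches are plainly equivalent, as the sketch below shows. The table size is invented for illustration.

```python
import torch

max_seq_len, half_dim = 8, 4
freqs_cos = torch.randn(max_seq_len, half_dim)  # stand-in for the precomputed table

input_pos = torch.tensor([3])
seqlen = 1

# enable_dynamic_shape=True branch: materialize start as an int, bound it
input_pos_item = input_pos[-1].item()
torch._check_is_size(input_pos_item)        # tells export: a non-negative size
torch._check(input_pos_item < max_seq_len)  # tells export: within the table
dynamic = freqs_cos.narrow(0, input_pos_item, seqlen)

# enable_dynamic_shape=False branch: plain indexing, no symint materialized
static = freqs_cos[input_pos]

assert torch.equal(dynamic, static)
```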