Make dynamic-shape-based export selectable (pytorch#4100)
Summary:
Pull Request resolved: pytorch#4100

Most backends, such as mps, qnn, and coreml, do not support dynamic shapes. For
mps in particular, the query to .item(), i.e. a data-dependent dynamic shape,
results in a lowering error. That error is likely fixable in the delegate, but
it is not clear whether the delegate is meant to support dynamic shapes at all.

//pre-existing lint error
bypass-github-export-checks

Thus, the dynamic-shape-specific changes are made selectable for now.
ghstack-source-id: 232298483

Reviewed By: cccclai

Differential Revision: D59236591

fbshipit-source-id: 2870148373f2f9f0ee9d4646355ac0664d2c2f68
kimishpatel authored and facebook-github-bot committed Jul 2, 2024
1 parent 5dc3b2b commit e74aa2e
Showing 3 changed files with 61 additions and 35 deletions.
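
For context, here is a minimal usage sketch (not part of this commit) of how the new flag is selected through ModelArgs and forwarded by model.py. It is a sketch under assumptions: the import path is inferred from the file locations below, the repository root is assumed to be on PYTHONPATH, and the remaining ModelArgs fields are assumed to keep their defaults.

# Illustrative only; names and import path inferred from the files changed below.
from examples.models.llama2.llama_transformer import ModelArgs

# Backends such as mps/qnn/coreml: keep the export free of data-dependent shapes.
static_args = ModelArgs(
    use_kv_cache=True,
    use_sdpa_with_kv_cache_op=False,
    enable_dynamic_shape=False,  # default; skips the .item() + narrow() path
)

# Opt in only when the target delegate can lower data-dependent dynamic shapes.
dynamic_args = ModelArgs(
    use_kv_cache=True,
    use_sdpa_with_kv_cache_op=False,
    enable_dynamic_shape=True,  # uses .item() + torch._check + narrow()
)
# model.py passes enable_dynamic_shape=self.enable_dynamic_shape into ModelArgs,
# so an export front end only needs to set that one attribute.
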
93 changes: 58 additions & 35 deletions examples/models/llama2/llama_transformer.py
@@ -88,6 +88,7 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    enable_dynamic_shape: bool = False  # export model with dynamic shape support
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
     )
@@ -188,6 +189,7 @@ def __init__(
         n_heads: int,
         head_dim: int,
         transpose_cache: bool,
+        enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
@@ -198,6 +200,7 @@ def __init__(
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
 
         self.transpose_cache = transpose_cache
+        self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
         )
@@ -209,23 +212,31 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
-        start_pos = input_pos[-1].item()
-        torch._check_is_size(start_pos)
-        torch._check(start_pos < self.max_seq_length)
-        seq_length = k_val.size(2)
-        # Replace the entry in the cache for this token
-        # The following lines are equivalent to:
-        # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-        # We use .narrow() here to make the compiler happy
-        # pyre-ignore: Incompatible parameter type [6]
-        narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
-        # pyre-ignore: Incompatible parameter type [6]
-        narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
-
-        narrowed_k.copy_(k_val)
-        narrowed_v.copy_(v_val)
-        return self.k_cache, self.v_cache
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_seq_length)
+            seq_length = k_val.size(2)
+            # Replace the entry in the cache for this token
+            # The following lines are equivalent to:
+            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            # We use .narrow() here to make the compiler happy
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+
+            narrowed_k.copy_(k_val)
+            narrowed_v.copy_(v_val)
+            return self.k_cache, self.v_cache
+        else:
+            k_out = self.k_cache
+            v_out = self.v_cache
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
+
+            return k_out, v_out
 
 
 class SDPA(nn.Module):
@@ -236,13 +247,15 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_seq_len = max_seq_len
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -259,12 +272,15 @@ def forward(
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        start_pos = input_pos[-1].item()
-        torch._check_is_size(start_pos)
-        torch._check(start_pos < self.max_seq_len)
-        seq_length = q.size(2)
-        # pyre-ignore: Incompatible parameter type [6]
-        attn_mask = mask.narrow(0, start_pos, seq_length)
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_seq_len)
+            seq_length = q.size(2)
+            # pyre-ignore: Incompatible parameter type [6]
+            attn_mask = mask.narrow(0, start_pos, seq_length)
+        else:
+            attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -312,13 +328,15 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.n_kv_heads,
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
                 kv_cache=self.kv_cache,
                 dim=self.dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
                 max_seq_len=self.max_seq_len,
+                enable_dynamic_shape=args.enable_dynamic_shape,
             )
 
     def forward(
@@ -496,18 +514,23 @@ def forward(
                 input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"
 
-            # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-            input_pos_item = input_pos[-1].item()
-            torch._check_is_size(input_pos_item)
-            # Setting this value to 32 for no particular reason.
-            # It is mainly to make export happy as the resulting
-            # asserts are ignored anyway.
-            # We really need unbounded start_pos
-            torch._check(input_pos_item < self.params.max_seq_len)
-            # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-            freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-            # pyre-ignore: Incompatible parameter type [6]
-            freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
 
         else:
             assert input_pos is None, "input_pos is unused when use_kv_cache is False"
             freqs_cos = self.freqs_cos[:seqlen]
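
To make the distinction concrete, here is a standalone sketch (not from this PR) contrasting the two lookup styles that enable_dynamic_shape toggles above. It assumes a recent PyTorch with torch.export available; FreqLookup and its buffer are made-up stand-ins for the freqs_cos/freqs_sin logic in llama_transformer.py.

# Standalone toy (not from this PR): the .item()-based path creates a
# data-dependent (unbacked) symint at export time that must be bounded via
# torch._check, which is what some delegates cannot lower.
import torch
from torch.export import export

MAX_SEQ_LEN = 128


class FreqLookup(torch.nn.Module):
    def __init__(self, enable_dynamic_shape: bool) -> None:
        super().__init__()
        self.enable_dynamic_shape = enable_dynamic_shape
        self.register_buffer("freqs", torch.randn(MAX_SEQ_LEN, 8))

    def forward(self, input_pos: torch.Tensor) -> torch.Tensor:
        if self.enable_dynamic_shape:
            # Reading the position out of the tensor makes it data dependent.
            start_pos = input_pos[-1].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < MAX_SEQ_LEN)
            return self.freqs.narrow(0, start_pos, 1)
        # Plain tensor indexing keeps the graph free of unbacked symints.
        return self.freqs[input_pos]


example = (torch.tensor([3]),)
static_prog = export(FreqLookup(False), example)   # what mps/qnn/coreml expect
dynamic_prog = export(FreqLookup(True), example)   # needs delegate support
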
1 change: 1 addition & 0 deletions examples/models/llama2/model.py
@@ -162,6 +162,7 @@ def __init__(self, **kwargs):
             max_batch_size=max_batch_size,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
+            enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
         if kwargs.get("fairseq2", False):
2 changes: 2 additions & 0 deletions examples/models/llama2/tests/test_simple_sdpa.py
@@ -30,13 +30,15 @@ def test_simple_sdpa(self):
             n_heads=n_heads,
             head_dim=head_dim,
             transpose_cache=True,
+            enable_dynamic_shape=False,
         )
         sdpa = SDPA(
             kv_cache=copy.deepcopy(kv_cache),
             dim=dim,
             head_dim=head_dim,
             n_rep=n_rep,
             max_seq_len=max_seq_length,
+            enable_dynamic_shape=False,
         )
         input_pos = torch.tensor([0])
         query = torch.randn(1, 1, n_local_heads, head_dim)
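
A possible follow-up, sketched below and not part of this commit, would be to also cover the dynamic path inside test_simple_sdpa and check that both update strategies fill the cache identically for a single-token update. This is a sketch under assumptions: the max_batch_size/max_seq_length keyword names are inferred from the cache_shape line in llama_transformer.py, and self.assertTrue assumes a unittest-style test class; it reuses the test's existing kv_cache, input_pos, and size variables.

# Hypothetical extension of test_simple_sdpa (not in this commit).
kv_cache_dynamic = KVCache(
    max_batch_size=1,
    max_seq_length=max_seq_length,
    n_heads=n_heads,
    head_dim=head_dim,
    transpose_cache=True,
    enable_dynamic_shape=True,
)
kv_cache_static = copy.deepcopy(kv_cache)
k_val = torch.randn(1, n_heads, 1, head_dim)  # [B, H, S, D] since transpose_cache=True
v_val = torch.randn(1, n_heads, 1, head_dim)
k_dyn, v_dyn = kv_cache_dynamic.update(input_pos, k_val, v_val)
k_sta, v_sta = kv_cache_static.update(input_pos, k_val, v_val)
self.assertTrue(torch.allclose(k_dyn, k_sta))
self.assertTrue(torch.allclose(v_dyn, v_sta))
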
