Single location to update optional args for all attentions
Differential Revision: D68988021

Pull Request resolved: pytorch#8128
iseeyuan authored Feb 1, 2025
1 parent e92bb7a commit a5c7609
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions examples/models/llama/attention.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type, TypedDict

 import torch
 import torch.nn as nn
@@ -8,6 +8,15 @@
 from executorch.examples.models.llama.rope import Rope


+class ForwardOptions(TypedDict, total=False):
+    """Optional parameters for `Attention.forward` (compatible with Python 3.10 and later)."""
+
+    mask: Optional[torch.Tensor]
+    input_pos: Optional[torch.Tensor]
+    in_cache_state: Optional[Any]
+    out_cache_state: Optional[Any]
+
+
 class Attention(nn.Module, ABC):
     """Abstract base class for attention mechanisms with unified interface."""

@@ -17,19 +26,14 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
         """Forward pass for attention mechanism.

         Args:
             x: Input tensor of shape (batch_size, seq_len, dim)
             freqs_cos, freqs_sin: Rotary position embedding frequencies
-            mask: Optional attention mask
-            input_pos: Positions for KV cache updates
-            in_cache_state/out_cache_state: Cache states
+            ForwardOptions: grouped optional args

         Returns:
             Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        input_pos = kwargs.get("input_pos")
         bsz, seqlen, _ = x.shape

         # QKV
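
For illustration only (not part of the diff above): a minimal sketch of how a concrete attention and its caller might use the grouped optional args, assuming the Attention and ForwardOptions definitions from this commit are importable; the NoCacheAttention class and the tensor values are hypothetical.

from typing import Any, Optional, Tuple

import torch

from executorch.examples.models.llama.attention import Attention, ForwardOptions


class NoCacheAttention(Attention):
    """Hypothetical subclass: optional args arrive via **kwargs and are read with .get()."""

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: ForwardOptions,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        mask = kwargs.get("mask")            # optional attention mask
        input_pos = kwargs.get("input_pos")  # positions for KV cache updates
        # Attention math elided; identity pass-through just to keep the sketch short.
        return x, kwargs.get("out_cache_state")


# Callers keep passing the same keyword names as before; a new optional arg only
# needs to be added to ForwardOptions, not to every attention subclass signature.
options: ForwardOptions = {"input_pos": torch.tensor([0])}
# out, cache = NoCacheAttention(...)(x, freqs_cos, freqs_sin, **options)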
