allow customized head_dim (pytorch#7065)
Pull Request resolved: pytorch#6872

This resolves the ask in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/).

Similar change in HF: huggingface/transformers#32502
ghstack-source-id: 255340016

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

Co-authored-by: Lunwen He <[email protected]>
pytorchbot and helunwencser authored Nov 25, 2024
1 parent 52fa043 commit a1f668d
Showing 1 changed file with 8 additions and 4 deletions.
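
With this change, `ModelArgs` gains an optional `head_dim` field; when it is left unset, `__post_init__` falls back to the previous derivation `dim // n_heads`. Below is a minimal standalone sketch of that fallback logic (an illustrative dataclass, not the actual `ModelArgs`; only the fields touched by the diff are mirrored, and the default values are made up):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ArgsSketch:
    """Illustrative stand-in for ModelArgs, reduced to the fields relevant here."""
    dim: int = 4096
    n_heads: int = 32
    head_dim: Optional[int] = None  # Optional customized head_dim

    def __post_init__(self) -> None:
        # Preserve the old behavior when head_dim is not given explicitly.
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads

assert ArgsSketch(dim=2048, n_heads=16).head_dim == 128              # derived as before
assert ArgsSketch(dim=2048, n_heads=16, head_dim=64).head_dim == 64  # customized
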
examples/models/llama/llama_transformer.py

@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -142,6 +143,9 @@ def __post_init__(self):
                 hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
             self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
@@ -272,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
Expand Down Expand Up @@ -304,7 +308,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
)
self.SDPA = SDPA(
kv_cache=self.kv_cache,
dim=self.dim,
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_seq_len=self.max_seq_len,
@@ -425,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -472,7 +476,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis, use_scaled=params.use_scaled_rope
         )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
+            params.head_dim,
             (
                 params.max_seq_len  # Normal llama2.
                 if params.ffn_dim_multiplier is None
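
A note on the SDPA hunk above (hypothetical numbers, not taken from any real checkpoint): once `head_dim` can be customized, `n_heads * head_dim` no longer has to equal `dim`, which is why the width passed to `SDPA` is now computed from the heads rather than taken from `args.dim`:

# Hypothetical shapes where head_dim != dim // n_heads.
dim, n_heads, head_dim = 2304, 8, 256
attn_width = n_heads * head_dim   # 2048: the width the attention projections actually produce
assert attn_width != dim          # so passing args.dim to SDPA would be the wrong width here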
