diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 20b8b1e30d..3f8b8dd654 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -142,6 +143,9 @@ def __post_init__(self):
                 hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
             self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
@@ -272,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
@@ -304,7 +308,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
            )
            self.SDPA = SDPA(
                kv_cache=self.kv_cache,
-               dim=self.dim,
+               dim=self.n_local_heads * self.head_dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_seq_len=self.max_seq_len,
@@ -425,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -472,7 +476,7 @@ def __init__(self, params: ModelArgs):
            precompute_freqs_cis, use_scaled=params.use_scaled_rope
        )
        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-           params.dim // params.n_heads,
+           params.head_dim,
           (
               params.max_seq_len  # Normal llama2.
               if params.ffn_dim_multiplier is None
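
For context, a minimal usage sketch (not part of the diff) of how the new optional `head_dim` field is expected to behave: left unset, `__post_init__` derives it as `dim // n_heads`; set explicitly, `Attention`, `TransformerBlock`, and the RoPE precompute all read `args.head_dim` instead of recomputing it. The import path and the assumption that `ModelArgs` is constructible from just `dim` and `n_heads` (other fields defaulted) are inferred from the file above, not confirmed by this diff.

```python
# Hypothetical sketch; assumes ModelArgs is importable from the module touched
# by this diff and that dim / n_heads / head_dim are dataclass fields with defaults.
from examples.models.llama.llama_transformer import ModelArgs

# head_dim left unset: __post_init__ falls back to dim // n_heads.
derived = ModelArgs(dim=4096, n_heads=32)
assert derived.head_dim == 128

# head_dim set explicitly: downstream code now uses args.head_dim directly,
# so it no longer has to equal dim // n_heads.
custom = ModelArgs(dim=4096, n_heads=32, head_dim=64)
assert custom.head_dim == 64
```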