diff --git a/examples/models/llama2/eval_llama.py b/examples/models/llama2/eval_llama.py index 0495c76bbf..4daeaf7afa 100644 --- a/examples/models/llama2/eval_llama.py +++ b/examples/models/llama2/eval_llama.py @@ -22,6 +22,8 @@ def main() -> None: modelname = "llama2" parser = build_args_parser() args = parser.parse_args() + # Overrides this arg, because evaluation requires full logits. + args.generate_full_logits = True eval_llama(modelname, args) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index c22c0a3c3c..8ff5d3aa26 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -296,6 +296,13 @@ def build_args_parser() -> argparse.ArgumentParser: help="Generate the ETRecord debug artifact.", ) + parser.add_argument( + "--generate_full_logits", + action="store_true", + required=False, + default=True, + help="Generate logits for all inputs.", + ) return parser @@ -405,6 +412,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: params_path=params_path, use_kv_cache=args.use_kv_cache, use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, + generate_full_logits=args.generate_full_logits, weight_type=weight_type, enable_dynamic_shape=args.enable_dynamic_shape, verbose=args.verbose, @@ -590,6 +598,7 @@ def _load_llama_model( params_path: str, use_kv_cache: bool = False, use_sdpa_with_kv_cache: bool = False, + generate_full_logits: bool = True, weight_type: WeightType = WeightType.LLAMA, enable_dynamic_shape: bool = False, verbose: bool = False, @@ -616,6 +625,7 @@ def _load_llama_model( params=params_path, use_kv_cache=use_kv_cache, use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, + generate_full_logits=generate_full_logits, fairseq2=weight_type == WeightType.FAIRSEQ2, max_seq_len=max_seq_len, enable_dynamic_shape=enable_dynamic_shape, diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 4ae12b0f64..81b47a3a5d 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -96,6 +96,10 @@ class ModelArgs: use_sdpa_with_kv_cache_op: bool = ( False # Use custom sdpa op that updates kv cache in-place ) + # Generate logits for all inputs. When it's True, it would take big memory usage + # at runtime. Enable it only necessary (e.g., use perplexity tools that requires + # logits for all input tokens.) + generate_full_logits: bool = True enable_dynamic_shape: bool = False # export model with dynamic shape support use_hf_rope: bool = False # Use HuggingFace's RoPE implementation rope_theta: Optional[float] = ( @@ -442,6 +446,7 @@ def __init__(self, params: ModelArgs): self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.use_kv_cache = params.use_kv_cache + self.generate_full_logits = params.generate_full_logits self.max_seq_len = params.max_seq_len if params.use_hf_rope: self.precompute_freqs_cis = hf_precompute_freqs_cis @@ -512,6 +517,10 @@ def forward( input_pos, ) + if not self.generate_full_logits: + # Only the last logit is used for the new generated token + h = h[:, -1, :] + h = self.norm(h) logits = self.output(h) diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index fdf0dc707e..b375399f33 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -61,6 +61,7 @@ def __init__(self, **kwargs): self.use_kv_cache = kwargs.get("use_kv_cache", False) self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False) + self.generate_full_logits = kwargs.get("generate_full_logits", True) self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False) self.max_seq_len = kwargs.get("max_seq_len", 128) @@ -145,6 +146,7 @@ def __init__(self, **kwargs): max_batch_size=max_batch_size, use_kv_cache=self.use_kv_cache, use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, + generate_full_logits=self.generate_full_logits, enable_dynamic_shape=self.enable_dynamic_shape, **params, )