buckify eval_llama (pytorch#5437)
Summary:
Pull Request resolved: pytorch#5437

This PR buckifies `eval_llama`. This is useful when we need to run eval with Buck.
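
A hypothetical invocation of the new binary target (flag names taken from the existing `eval_llama` argument parser shown in the diff below; paths and task names illustrative): `buck run //executorch/examples/models/llama2:eval_llama -- --tokenizer_path /path/to/tokenizer.model --tasks wikitext --limit 10`.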

Reviewed By: mergennachin

Differential Revision: D62897016

fbshipit-source-id: 59cc64eaa3b29f707b9aa7d3ac2568a16c2743c9
Lunwen He authored and facebook-github-bot committed Sep 30, 2024
1 parent f0662bb commit b60fa71
Showing 5 changed files with 64 additions and 19 deletions.
40 changes: 40 additions & 0 deletions examples/models/llama2/TARGETS
@@ -104,5 +104,45 @@ runtime.python_library(
"//executorch/util:python_profiler",
"fbsource//third-party/pypi/coremltools:coremltools",
"fbsource//third-party/pypi/sentencepiece:sentencepiece",
"//pytorch/ao:torchao",
],
)

runtime.python_binary(
    name = "eval_llama",
    main_function = "executorch.examples.models.llama2.eval_llama.main",
    preload_deps = [
        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
        "//executorch/kernels/quantized:aot_lib",
    ],
    deps = [
        ":eval_library",
        "//caffe2:torch",
    ],
)

runtime.python_library(
    name = "eval_library",
    srcs = [
        "eval_llama.py",
        "eval_llama_lib.py",
        "evaluate/eager_eval.py",
    ],
    _is_external_target = True,
    base_module = "executorch.examples.models.llama2",
    visibility = [
        "//bento/...",
        "//bento_kernels/...",
        "//executorch/examples/...",
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        "fbsource//third-party/pypi/lm-eval:lm-eval",
        "fbsource//third-party/pypi/tiktoken:tiktoken",
        ":export_library",
        "//executorch/examples/models/llama2/tokenizer:tiktoken_py",
        "//executorch/extension/llm/export:export_lib",
        "//executorch/extension/llm/tokenizer:tokenizer_py_lib",
        "//executorch/extension/pybindings:portable_lib",
    ],
)
2 changes: 1 addition & 1 deletion examples/models/llama2/eval_llama.py
@@ -24,7 +24,7 @@ def main() -> None:
    args = parser.parse_args()
    # Overrides this arg, because evaluation requires full logits.
    args.generate_full_logits = True
    eval_llama(modelname, args)
    eval_llama(modelname, args) # pyre-ignore


if __name__ == "__main__":
31 changes: 16 additions & 15 deletions examples/models/llama2/eval_llama_lib.py
@@ -10,19 +10,20 @@
from typing import Optional, Union

import torch
from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
from executorch.examples.models.llama2.export_llama_lib import (
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken

from executorch.extension.llm.export import LLMEdgeManager
from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.api.model import LM

from .evaluate.eager_eval import EagerEvalWrapper, evaluate_model

from .export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
@@ -91,7 +92,7 @@ def __init__(
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)
        super().__init__(None, tokenizer, max_seq_length) # pyre-ignore
        self._model = model # Expects model to be path to a .pte file

        from executorch.extension.pybindings.portable_lib import _load_for_executorch
@@ -106,7 +107,7 @@ def __init__(
        from executorch.kernels import quantized # noqa

        self._et_model = _load_for_executorch(self._model)
        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]
        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single forward call
@@ -140,7 +141,7 @@ def __init__(
        tokenizer_bin: str,
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)
        super().__init__(None, tokenizer, max_seq_length) # pyre-ignore
        self._model = model
        self._tokenizer_bin = tokenizer_bin

@@ -165,17 +166,17 @@ def gen_eval_wrapper(
    Returns:
        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
    """
    tokenizer = get_tokenizer(args.tokenizer_path)
    tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore

    # ExecuTorch Binary Evaluation
    if (model := args.pte) is not None:
        if (tokenizer_bin := args.tokenizer_bin) is not None:
    if (model := args.pte) is not None: # pyre-ignore
        if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
                max_seq_length=args.max_seq_length,
                max_seq_length=args.max_seq_length, # pyre-ignore
            )

    # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -194,16 +195,16 @@ def gen_eval_wrapper(
    if len(quantizers) != 0:
        manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
        model = (
            manager.pre_autograd_graph_module.to(device="cuda")
            manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore
            if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
        )
        return GraphModuleEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,
            enable_dynamic_shape=args.enable_dynamic_shape,
            use_kv_cache=args.use_kv_cache, # pyre-ignore
            enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore
        )
    else:
        # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -221,7 +222,7 @@ def gen_eval_wrapper(
        # that is not available in this eval_llama. We save the checkpoint
        # here for consistency with eval_llama. The accuracy results we
        # get from eval_llama can be used as a reference to other evaluations.
        if args.output_eager_checkpoint_file is not None:
        if args.output_eager_checkpoint_file is not None: # pyre-ignore
            torch.save(model, args.output_eager_checkpoint_file)

        return EagerEvalWrapper(
@@ -282,8 +283,8 @@ def eval_llama(
    # Evaluate the model
    eval_results = evaluate_model(
        eval_wrapper,
        args.tasks,
        args.limit,
        args.tasks, # pyre-ignore
        args.limit, # pyre-ignore
    )

    for task, res in eval_results["results"].items():
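
For orientation, the dispatch that gen_eval_wrapper implements above reduces to the following self-contained sketch. The stub class stands in for the real wrappers (ETRunnerEvalWrapper, ETPybindEvalWrapper, GraphModuleEvalWrapper, EagerEvalWrapper); only the branching mirrors the diff, and constructor arguments are abbreviated:

import argparse

class _StubWrapper:
    # Placeholder for the real eval wrappers; it only records its kwargs.
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def dispatch_eval_wrapper(args):
    if (model := args.pte) is not None:
        if (tokenizer_bin := args.tokenizer_bin) is not None:
            # A .pte file evaluated at runtime via the native runner binary.
            return _StubWrapper(kind="runner", model=model, tokenizer_bin=tokenizer_bin)
        # A .pte file evaluated in-process through the pybindings.
        return _StubWrapper(kind="pybind", model=model)
    # No .pte given: the export path is taken; quantized graphs are wrapped
    # in GraphModuleEvalWrapper, everything else in EagerEvalWrapper.
    return _StubWrapper(kind="eager")

print(dispatch_eval_wrapper(argparse.Namespace(pte=None, tokenizer_bin=None)).kwargs)
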
6 changes: 4 additions & 2 deletions examples/models/llama2/evaluate/eager_eval.py
@@ -62,7 +62,7 @@ def batch_size(self):
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
    def tok_encode(self, string: str, **kwargs): # pyre-ignore
        tokens = self._tokenizer.encode(string, bos=True, eos=False)
        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
        # encoded is a pytorch tensor, but some internal logic in the
@@ -111,7 +111,9 @@ def evaluate_model(

if "hendrycks_test" in tasks:
tasks.remove("hendrycks_test")
tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
tasks += list(
lm_eval.tasks.hendrycks_test.create_all_tasks().keys() # pyre-ignore
)
task_dict = get_task_dict(tasks)

    eval_results = evaluate(
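
The hendrycks_test handling above is an alias-expansion step: the umbrella task name is replaced by its concrete subtasks before get_task_dict is called. A minimal illustration of the pattern, with made-up subtask names standing in for whatever lm-eval's create_all_tasks() actually returns:

def expand_alias(tasks, alias, subtasks):
    # Replace an umbrella task name with its concrete subtasks, in place.
    if alias in tasks:
        tasks.remove(alias)
        tasks += list(subtasks)
    return tasks

print(expand_alias(["wikitext", "hendrycks_test"], "hendrycks_test",
                   ["hendrycks_test-anatomy", "hendrycks_test-astronomy"]))
# ['wikitext', 'hendrycks_test-anatomy', 'hendrycks_test-astronomy']
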
4 changes: 3 additions & 1 deletion extension/llm/export/builder.py
@@ -209,7 +209,9 @@ def pt2e_calibrate(
            from executorch.examples.models.llama2.eval_llama_lib import (
                GraphModuleEvalWrapper,
            )
            from executorch.examples.models.llama2.evaluate import evaluate_model
            from executorch.examples.models.llama2.evaluate import ( # pyre-ignore[21]
                evaluate_model,
            )
        except ImportError:
            raise ImportError(
                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
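
The builder.py hunk keeps the eval dependencies out of module import time by deferring them into the function body and turning a failure into an actionable message. A minimal sketch of that guarded-import pattern, with a hypothetical module name:

def calibrate_with_optional_deps():
    try:
        # Imported lazily so this module loads even without the eval extras.
        import hypothetical_eval_dep  # noqa: F401 -- hypothetical dependency
    except ImportError:
        raise ImportError(
            "Please install the llm eval dependency via "
            "examples/models/llama2/install_requirements.sh"
        )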
