From 5d58203f4fb93f773553c55dbb32c8d19cb8d801 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 22 Jul 2024 23:19:01 -0700
Subject: [PATCH] Disable sdpa_with_kv_cache for now (#4319)

Summary:
For some reason the `sdpa_with_kv_cache` custom op gives wrong results after
prefill. This PR disables it and adds a couple of eager-mode unit tests.

Pull Request resolved: https://github.com/pytorch/executorch/pull/4319

Reviewed By: helunwencser

Differential Revision: D59988495

Pulled By: larryliu0820

fbshipit-source-id: cce45791f6f492fe3ddc39ef2ad4401ea3dfc407
---
 .github/workflows/pull.yml                    |  3 +
 examples/models/llava/install_requirements.sh | 29 +++++---
 examples/models/llava/main.py                 | 66 ------------------
 examples/models/llava/model.py                |  5 +-
 examples/models/llava/test/test_llava.py      | 68 +++++++++++++++++++
 pytest.ini                                    |  1 +
 6 files changed, 96 insertions(+), 76 deletions(-)
 delete mode 100644 examples/models/llava/main.py
 create mode 100644 examples/models/llava/test/test_llava.py

diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index a05b0833e1..36099ca651 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -218,6 +218,9 @@ jobs:
           exit 1
         fi
 
+        # run python unittest
+        python -m unittest examples.models.llava.test.test_llava
+
   test-quantized-aot-lib-linux:
     name: test-quantized-aot-lib-linux
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
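The eager-mode check this CI step runs (added in test_llava.py further below) boils down to feeding the same prefill inputs through both the custom-op path and the reference path and comparing logits. Condensed into a standalone script, the pattern looks roughly like this; it is a sketch built only from the LlavaModel/Llava methods that appear in this patch:

    import torch

    from executorch.examples.models.llava.model import LlavaModel

    llava_model = LlavaModel()
    llava = llava_model.get_eager_model()
    before, resized, after = llava_model.get_inputs_for_prefill()

    # The two prefill implementations should agree within a loose tolerance.
    logits = llava.prefill(before, resized, after)
    logits_ref = llava.prefill_ref(before, resized, after)[0]
    assert torch.allclose(logits, logits_ref, atol=3e-2)
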
diff --git a/examples/models/llava/install_requirements.sh b/examples/models/llava/install_requirements.sh
index d6d6f3e32b..68923c2dad 100644
--- a/examples/models/llava/install_requirements.sh
+++ b/examples/models/llava/install_requirements.sh
@@ -6,9 +6,23 @@
 # LICENSE file in the root directory of this source tree.
 
 set -x
+OS=$(uname)
 
-# install llava from the submodule
-pip install --force-reinstall -e examples/third-party/LLaVA
+# install llava from the submodule. We can't do `pip install llava` because it is packaged incorrectly.
+if [[ $OS != "Darwin" ]];
+then
+    # This doesn't work for macOS on Python 3.12 because torch 2.1.2 is missing.
+    pip install --force-reinstall -e examples/third-party/LLaVA
+else
+    # manually install dependencies
+    pip install tokenizers==0.15.1 sentencepiece==0.1.99 \
+        shortuuid accelerate==0.21.0 peft \
+        pydantic markdown2[all] scikit-learn==1.2.2 \
+        requests httpx==0.24.0 uvicorn fastapi \
+        einops==0.6.1 einops-exts==0.0.4 timm==0.6.13
+
+    pip install --force-reinstall -e examples/third-party/LLaVA --no-deps
+fi
 
 # not included in the pip install package, but needed in llava
 pip install protobuf
@@ -17,15 +31,14 @@ pip install protobuf
 # Reinstall bitsandbytes to make it compatible.
 pip install bitsandbytes -I
 
-# numpy needs to be pin to 1.24. 1.26.4 will error out
-pip install numpy==1.24
-
 # The deps of llava can have different versions than deps of ExecuTorch.
 # For example, torch version required from llava is older than ExecuTorch.
 # To make both work, recover ExecuTorch's original dependencies by rerunning
-# the install_requirements.sh.
+# the install_requirements.sh. Note that this won't install executorch.
 bash -x ./install_requirements.sh --pybind xnnpack
 
-# Newer transformer will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
+# Newer transformers (4.38) will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
 pip install timm==0.6.13
-pip install transformers==4.38.2
+pip install transformers==4.37.2
+
+pip list
diff --git a/examples/models/llava/main.py b/examples/models/llava/main.py
deleted file mode 100644
index f531a41852..0000000000
--- a/examples/models/llava/main.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-
-import torch
-
-from model import LlavaModel
-
-
-def main():
-
-    llava_model = LlavaModel()
-    llava = llava_model.get_eager_model()
-
-    prompt_before_image, resized, prompt_after_image = llava_model.get_example_inputs()
-    logging.info(f"Prompt: {llava_model.prompt}")
-    preprocessed = llava.image_preprocess(resized)
-    with torch.inference_mode():
-        output_ids = llava_model.model.generate(
-            llava_model.input_ids,
-            images=preprocessed,
-            image_sizes=[preprocessed.size],
-            do_sample=False,
-            num_beams=1,
-            max_new_tokens=10,
-            use_cache=True,
-        )
-
-    outputs = llava_model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
-        0
-    ].strip()
-    logging.info(f"Reference output: {outputs}")
-
-    # comparing with llava result
-    # prefill_logits = llava.prefill(prompt_before_image, resized, prompt_after_image)
-    # prefill_logits_ref = llava.prefill_ref(*inputs)[0]
-    # print(f"Prefill logits all close? {torch.allclose(prefill_logits, prefill_logits_ref, atol=1e-3)}")
-
-    # prefill_logits = llava.prefill(*inputs)
-    # context_len = prefill_logits.shape[1]
-    # print(prefill_logits)
-    # # first token
-    # new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
-    # # print(tokenizer.decode(new_tokens))
-    # for i in range(llava_model.args.max_new_tokens):
-    #     print(i, llava_model.tokenizer.decode(new_tokens[i]))
-    #     logits = llava.forward(
-    #         torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
-    #     )
-    #     new_tokens.append(torch.argmax(logits[-1, :]))
-    prefill_logits = llava.prefill(prompt_before_image, resized, prompt_after_image)
-    context_len = prefill_logits.shape[1]
-    logging.info(prefill_logits)
-    new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
-    i = 0
-    logging.info(i, llava_model.tokenizer.decode(new_tokens[i]))
-    logits = llava.step(torch.tensor([new_tokens[i]]), torch.tensor([context_len + i]))
-    logging.info(logits)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 5a5b5716f7..31270b9042 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -17,6 +17,7 @@
 import torch
 import torchvision
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+
 from executorch.examples.models.llama2.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
@@ -72,7 +73,7 @@ def __init__(
             max_batch_size=1,  # doesn't work with default batch size 32
             ffn_dim_multiplier=1,  # TODO: a hack to make rotary embedding happy
             enable_dynamic_shape=True,  # allow parallel prefill
-            use_sdpa_with_kv_cache_op=True,
+            use_sdpa_with_kv_cache_op=False,  # use sdpa_with_kv_cache op
             use_hf_rope=True,
         )
         self.embed_tokens = nn.Embedding(
@@ -81,7 +82,7 @@ def __init__(
             self.model_.config.pad_token_id,
         )
         self.text_model = Transformer(self.text_model_args)
-        # use custom op for SDPA
+        # use custom op for SDPA.
         self.text_model = replace_sdpa_with_custom_op(self.text_model)
         # load state dict
        self.text_model.load_state_dict(
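For context on the model.py change: replace_sdpa_with_custom_op is a source transformation, and transformations of this kind typically walk the module tree and swap matching attention submodules in place. A minimal sketch of that general pattern follows; the names `target` and `build` are illustrative and not the ExecuTorch API:

    import torch.nn as nn

    def replace_submodules(model: nn.Module, target: type, build) -> nn.Module:
        # Recursively swap every child module of type `target` with the
        # module returned by `build(old_child)`, mutating `model` in place.
        for name, child in model.named_children():
            if isinstance(child, target):
                setattr(model, name, build(child))
            else:
                replace_submodules(child, target, build)
        return model
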
diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py
new file mode 100644
index 0000000000..ce0a527bc9
--- /dev/null
+++ b/examples/models/llava/test/test_llava.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import unittest
+
+import torch
+
+from executorch.examples.models.llava.model import LlavaModel
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class TestLlava(unittest.TestCase):
+    def setUp(self):
+        self.llava_model = LlavaModel()
+        self.llava = self.llava_model.get_eager_model()
+        self.prompt_before_image, self.resized, self.prompt_after_image = (
+            self.llava_model.get_inputs_for_prefill()
+        )
+
+    def test_prefill_logits(self):
+        prefill_logits = self.llava.prefill(
+            self.prompt_before_image, self.resized, self.prompt_after_image
+        )
+        prefill_logits_ref = self.llava.prefill_ref(
+            self.prompt_before_image, self.resized, self.prompt_after_image
+        )[0]
+        self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))
+
+    def test_generated_output(self):
+        # source of truth, using HF llava
+        preprocessed = self.llava.image_preprocess(self.resized)
+        with torch.inference_mode():
+            output_ids = self.llava_model.model.generate(
+                self.llava_model.input_ids,
+                images=preprocessed,
+                image_sizes=[preprocessed.size],
+                do_sample=False,
+                num_beams=1,
+                max_new_tokens=5,
+                use_cache=True,
+            )
+
+        ref_outputs = self.llava_model.tokenizer.batch_decode(
+            output_ids, skip_special_tokens=True
+        )[0].strip()
+
+        # being tested, using llama_transformer
+        prefill_logits = self.llava.prefill(
+            self.prompt_before_image, self.resized, self.prompt_after_image
+        )
+        context_len = prefill_logits.shape[1]
+        new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
+        for i in range(4):
+            logits = self.llava.step(
+                torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
+            )
+            new_tokens.append(torch.argmax(logits[-1, :]).item())
+
+        outputs = self.llava_model.tokenizer.batch_decode(
+            torch.tensor([new_tokens]), skip_special_tokens=True
+        )[0].strip()
+        self.assertEqual(outputs, ref_outputs)
diff --git a/pytest.ini b/pytest.ini
index 8443e48e5c..5ed1780e61 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -17,6 +17,7 @@ addopts =
     sdk/
     # examples
     examples/models/llama2/tests
+    # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
    exir/backend/test
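The decode loop in test_generated_output is the usual prefill-then-step greedy pattern: prefill once over the full multimodal prompt to populate the KV cache, then feed the argmax token back one step at a time. Factored out of the test, it looks roughly like this (a sketch over the eager model's prefill/step methods shown above):

    import torch

    def greedy_decode(llava, before, resized, after, max_new_tokens=5):
        # Prefill fills the KV cache and yields logits for every prompt
        # position; the argmax at the last position is the first new token.
        prefill_logits = llava.prefill(before, resized, after)
        context_len = prefill_logits.shape[1]
        new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
        # Each step consumes one token plus its absolute cache position.
        for i in range(max_new_tokens - 1):
            logits = llava.step(
                torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
            )
            new_tokens.append(torch.argmax(logits[-1, :]).item())
        return new_tokens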