Skip to content

Commit

Permalink
Disable sdpa_with_kv_cache for now (pytorch#4319)
Browse files Browse the repository at this point in the history
Summary:
For some reason `sdpa_with_kv_cache` custom op gives wrong result after prefill. This PR disables it and add a couple of eager mode unit tests.

Pull Request resolved: pytorch#4319

Reviewed By: helunwencser

Differential Revision: D59988495

Pulled By: larryliu0820

fbshipit-source-id: cce45791f6f492fe3ddc39ef2ad4401ea3dfc407
  • Loading branch information
larryliu0820 authored and facebook-github-bot committed Jul 23, 2024
1 parent 6556991 commit 5d58203
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ jobs:
exit 1
fi
# run python unittest
python -m unittest examples.models.llava.test.test_llava
test-quantized-aot-lib-linux:
name: test-quantized-aot-lib-linux
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
Expand Down
29 changes: 21 additions & 8 deletions examples/models/llava/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,23 @@
# LICENSE file in the root directory of this source tree.

set -x
OS=$(uname)

# install llava from the submodule
pip install --force-reinstall -e examples/third-party/LLaVA
# install llava from the submodule. We can't do pip install llava because it is packaged incorrectly.
if [[ $OS != "Darwin" ]];
then
#This doesn't work for macos, on python 3.12, because torch 2.1.2 is missing.
pip install --force-reinstall -e examples/third-party/LLaVA
else
# manually install dependencies
pip install tokenizers==0.15.1 sentencepiece==0.1.99 \
shortuuid accelerate==0.21.0 peft \
pydantic markdown2[all] scikit-learn==1.2.2 \
requests httpx==0.24.0 uvicorn fastapi \
einops==0.6.1 einops-exts==0.0.4 timm==0.6.13

pip install --force-reinstall -e examples/third-party/LLaVA --no-deps
fi

# not included in the pip install package, but needed in llava
pip install protobuf
Expand All @@ -17,15 +31,14 @@ pip install protobuf
# Reinstall bitsandbytes to make it compatible.
pip install bitsandbytes -I

# numpy needs to be pin to 1.24. 1.26.4 will error out
pip install numpy==1.24

# The deps of llava can have different versions than deps of ExecuTorch.
# For example, torch version required from llava is older than ExecuTorch.
# To make both work, recover ExecuTorch's original dependencies by rerunning
# the install_requirements.sh.
# the install_requirements.sh. Notice this won't install executorch.
bash -x ./install_requirements.sh --pybind xnnpack

# Newer transformer will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
# Newer transformer (4.38) will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
pip install timm==0.6.13
pip install transformers==4.38.2
pip install transformers==4.37.2

pip list
66 changes: 0 additions & 66 deletions examples/models/llava/main.py

This file was deleted.

5 changes: 3 additions & 2 deletions examples/models/llava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
import torchvision
from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

from executorch.examples.models.llama2.source_transformation.sdpa import (
replace_sdpa_with_custom_op,
)
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(
max_batch_size=1, # doesn't work with default batch size 32
ffn_dim_multiplier=1, # TODO: a hack to make rotary embedding happy
enable_dynamic_shape=True, # allow parallel prefill
use_sdpa_with_kv_cache_op=True,
use_sdpa_with_kv_cache_op=True, # use sdpa_with_kv_cache op
use_hf_rope=True,
)
self.embed_tokens = nn.Embedding(
Expand All @@ -81,7 +82,7 @@ def __init__(
self.model_.config.pad_token_id,
)
self.text_model = Transformer(self.text_model_args)
# use custom op for SDPA
# use custom op for SDPA.
self.text_model = replace_sdpa_with_custom_op(self.text_model)
# load state dict
self.text_model.load_state_dict(
Expand Down
68 changes: 68 additions & 0 deletions examples/models/llava/test/test_llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

import torch

from executorch.examples.models.llava.model import LlavaModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TestLlava(unittest.TestCase):
def setUp(self):
self.llava_model = LlavaModel()
self.llava = self.llava_model.get_eager_model()
self.prompt_before_image, self.resized, self.prompt_after_image = (
self.llava_model.get_inputs_for_prefill()
)

def test_prefill_logits(self):
prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
prefill_logits_ref = self.llava.prefill_ref(
self.prompt_before_image, self.resized, self.prompt_after_image
)[0]
self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

def test_generated_output(self):
# source of truth, using HF llava
preprocessed = self.llava.image_preprocess(self.resized)
with torch.inference_mode():
output_ids = self.llava_model.model.generate(
self.llava_model.input_ids,
images=preprocessed,
image_sizes=[preprocessed.size],
do_sample=False,
num_beams=1,
max_new_tokens=5,
use_cache=True,
)

ref_outputs = self.llava_model.tokenizer.batch_decode(
output_ids, skip_special_tokens=True
)[0].strip()

# being tested, using llama_transformer
prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
context_len = prefill_logits.shape[1]
new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
for i in range(4):
logits = self.llava.step(
torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
)
new_tokens.append(torch.argmax(logits[-1, :]).item())

outputs = self.llava_model.tokenizer.batch_decode(
torch.tensor([new_tokens]), skip_special_tokens=True
)[0].strip()
self.assertEqual(outputs, ref_outputs)
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ addopts =
sdk/
# examples
examples/models/llama2/tests
# examples/models/llava/test TODO: enable this
# exir
exir/_serialize/test
exir/backend/test
Expand Down

0 comments on commit 5d58203

Please sign in to comment.