diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py
index 8c3943cbcb..2d9c41b603 100644
--- a/examples/models/llama3_2_vision/text_decoder/model.py
+++ b/examples/models/llama3_2_vision/text_decoder/model.py
@@ -17,6 +17,7 @@
 )
 
 from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
 from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
 
@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
-        self.dtype = None
+        self.dtype = kwargs.get("dtype", torch.float16)
         self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
                 dtype=torch.bool,
             )
         )
-        self.input_pos = torch.arange(self.max_seq_len)
+        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)
 
         # Load checkpoint and params.
         device = "cpu"
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
             rope_base=params["rope_theta"],
             intermediate_dim=params["intermediate_dim"],
         )
+
+        # Source transformation for MultiHeadAttention
+        self.model_ = replace_mha_with_inference_mha(self.model_)
         # Save params for future use.
         for param_name, param_val in params.items():
             setattr(self.model_, param_name, param_val)
@@ -147,27 +151,33 @@ def __init__(self, **kwargs):
             self.model_.setup_caches(
                 batch_size=1,
                 dtype=self.dtype,
+                encoder_max_seq_len=self.encoder_max_seq_len,
                 decoder_max_seq_len=self.max_seq_len,
             )
+        # number of tokens for example input
+        self.n_tokens = 34
+        self.model_.to(self.dtype)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            return self.model_.to(self.dtype)
-        else:
-            return self.model_.to(torch.float16)
+        return self.model_
 
     def get_example_inputs(self):
-        return (torch.ones(1, 32, dtype=torch.long),)
+        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)
 
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, :32],
-                "mask": self.causal_mask[None, :32],
-                # "encoder_input": None,
-                # "encoder_mask": None,
+                "input_pos": self.input_pos[None, : self.n_tokens],
+                "mask": self.causal_mask[None, : self.n_tokens],
+                "encoder_input": torch.randn(
+                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
+                ),
+                "encoder_mask": torch.ones(
+                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
+                ),
             }
         else:
             return None
@@ -175,11 +185,12 @@ def get_example_kwarg_inputs(self):
     def get_dynamic_shapes(self):
         batch_size = 1
         dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             dynamic_shapes = {
                 "tokens": {0: batch_size, 1: dim_seq_len},
-                # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
-                # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+                "encoder_input": None,
+                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                 "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                 "input_pos": {0: batch_size, 1: dim_seq_len},
             }
diff --git a/examples/models/llama3_2_vision/text_decoder/test/__init__.py b/examples/models/llama3_2_vision/text_decoder/test/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py b/examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py
new file mode 100644
index 0000000000..8e678801b8
--- /dev/null
+++ b/examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Export and ExecuTorch tests for the Llama 3.2 vision text decoder are covered by test_models.sh.
+# Only the AOTI path is tested in this file.
+import json
+import os
+import tempfile
+import unittest
+
+import torch
+
+from executorch.examples.models.llama3_2_vision.text_decoder.model import (
+    Llama3_2Decoder,
+)
+from torch.testing import assert_close
+
+params = {
+    "dim": 2048,
+    "ffn_dim_multiplier": 1.3,
+    "fusion_interval": 2,
+    "intermediate_dim": 14336,
+    "multiple_of": 1024,
+    "n_heads": 32,
+    "n_kv_heads": 8,
+    "n_layers": 2,
+    "n_special_tokens": 8,
+    "norm_eps": 1e-05,
+    "rope_theta": 500000.0,
+    "use_scaled_rope": True,
+    "vision_chunk_size": 560,
+    "vision_max_num_chunks": 4,
+    "vocab_size": 21008,
+    "vision_num_cross_attention_layers": 1,
+}
+
+
+class TextDecoderTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def _set_requires_grad_false(self, model: torch.nn.Module) -> None:
+        for param in model.parameters():
+            param.requires_grad = False
+        for child in model.children():
+            self._set_requires_grad_false(child)
+
+    def test_llama3_2_text_decoder_aoti(self) -> None:
+        with tempfile.NamedTemporaryFile(mode="w") as param_file:
+            json.dump(params, param_file, indent=2)
+            param_file.flush()
+            model = Llama3_2Decoder(
+                encoder_max_seq_len=6404,
+                generate_full_logits=True,
+                enable_dynamic_shape=True,
+                use_kv_cache=True,
+                params=param_file.name,
+                dtype=torch.float32,
+            )
+        encoder = model.get_eager_model().eval()
+        self._set_requires_grad_false(encoder)
+
+        # AOTI
+        with torch.no_grad(), torch.inference_mode():
+            ep = torch.export.export(
+                encoder,
+                model.get_example_inputs(),
+                kwargs=model.get_example_kwarg_inputs(),
+                dynamic_shapes=model.get_dynamic_shapes(),
+            )
+            with tempfile.TemporaryDirectory() as tmpdir:
+                path = torch._inductor.aoti_compile_and_package(
+                    ep,
+                    model.get_example_inputs(),
+                    kwargs=model.get_example_kwarg_inputs(),
+                    package_path=os.path.join(tmpdir, "text_decoder.pt2"),
+                )
+                encoder_aoti = torch._inductor.aoti_load_package(path)
+
+                y = encoder_aoti(
+                    *model.get_example_inputs(), **model.get_example_kwarg_inputs()
+                )
+
+            eager_res = encoder.forward(
+                *model.get_example_inputs(), **model.get_example_kwarg_inputs()
+            )
+            assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
diff --git a/pytest.ini b/pytest.ini
index 5c6ca90fc7..3d1adccf2e 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -18,6 +18,7 @@ addopts =
     examples/models/llama/tests
     examples/models/llama3_2_vision/preprocess
     examples/models/llama3_2_vision/vision_encoder/test
+    examples/models/llama3_2_vision/text_decoder/test
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
@@ -43,8 +44,8 @@ addopts =
     extension/pybindings/test
     # Runtime
     runtime
-    # test
-    test/end2end/test_end2end.py
+    # test TODO: fix these tests
+    # test/end2end/test_end2end.py
     --ignore=backends/xnnpack/test/ops/linear.py
    --ignore=backends/xnnpack/test/models/llama2_et_example.py
     # T200992559: Add torchao to ET as core dependency
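
For reference, below is a minimal usage sketch of the updated example-input helpers; it is not part of the patch. It assumes a params JSON with the same reduced config as the test above has been saved at "params.json" (a hypothetical path), reuses the constructor arguments from the test, and picks an arbitrary 8-token prefill to illustrate that, per the dynamic shapes declared in get_dynamic_shapes(), the exported program should accept a sequence length other than the 34-token example. Behavior at other lengths is not verified by this patch.

# Usage sketch only -- not part of this patch.
# Assumes "params.json" (hypothetical path) holds the same reduced config
# that the test above writes to a temporary file.
import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,
)

model = Llama3_2Decoder(
    encoder_max_seq_len=6404,
    generate_full_logits=True,
    enable_dynamic_shape=True,
    use_kv_cache=True,
    params="params.json",
    dtype=torch.float32,
)
decoder = model.get_eager_model().eval()

# Export with the example inputs, kwargs, and dynamic shapes the wrapper provides.
ep = torch.export.export(
    decoder,
    model.get_example_inputs(),
    kwargs=model.get_example_kwarg_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)

# The token dimension is dynamic (min=1, max=max_seq_len), so the exported
# module should accept a prefill length other than the 34-token example as
# long as input_pos, mask, and encoder_mask are sliced to match.
n = 8  # arbitrary prefill length, for illustration only
example_kwargs = model.get_example_kwarg_inputs()
out = ep.module()(
    torch.ones(1, n, dtype=torch.int64),
    input_pos=model.input_pos[None, :n],
    mask=model.causal_mask[None, :n],
    encoder_input=example_kwargs["encoder_input"],
    encoder_mask=torch.ones([1, n, model.encoder_max_seq_len], dtype=torch.bool),
)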