[llama-mm] Make text decoder exportable (pytorch#6999)
* [llama-mm] Make text decoder exportable

Summary: Adds a source transformation and changes the example inputs to make
the text decoder exportable; a sketch of the resulting export flow appears
below, before the file diffs.

Test Plan: Added new unit test.

* Make the test model smaller

* Ignore e2e test

* Do not run e2e test

* Address comments
larryliu0820 authored Nov 21, 2024
1 parent 2adb1bc commit 8d71cd3
Showing 4 changed files with 117 additions and 15 deletions.
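
For orientation, here is a minimal sketch of the export flow this change enables. It is a sketch only: it assumes the Llama3_2Decoder interface shown in the model.py diff below, mirrors the constructor arguments used in the new unit test, and uses a placeholder path for the params JSON file.

import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,
)

# Llama3_2Decoder.__init__ now applies replace_mha_with_inference_mha, so the
# eager module returned below already uses the export-friendly attention.
model = Llama3_2Decoder(
    encoder_max_seq_len=6404,
    generate_full_logits=True,
    enable_dynamic_shape=True,
    use_kv_cache=True,
    params="/path/to/params.json",  # placeholder; see the params dict in the new test
    dtype=torch.float32,
)
decoder = model.get_eager_model().eval()

# Export with the prefill-style example inputs, kwargs, and dynamic shapes
# added by this change.
with torch.no_grad():
    ep = torch.export.export(
        decoder,
        model.get_example_inputs(),
        kwargs=model.get_example_kwarg_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
    )
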
37 changes: 24 additions & 13 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -17,6 +17,7 @@
)

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.verbose = kwargs.get("verbose", False)
self.args = kwargs.get("args", None)
self.dtype = None
self.dtype = kwargs.get("dtype", torch.float16)
self.use_checkpoint = False

ckpt_dir = get_default_model_resource_dir(__file__)
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
dtype=torch.bool,
)
)
self.input_pos = torch.arange(self.max_seq_len)
self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)

# Load checkpoint and params.
device = "cpu"
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
rope_base=params["rope_theta"],
intermediate_dim=params["intermediate_dim"],
)

# Source transformation for MultiHeadAttention
self.model_ = replace_mha_with_inference_mha(self.model_)
# Save params for future use.
for param_name, param_val in params.items():
setattr(self.model_, param_name, param_val)
@@ -147,39 +151,46 @@ def __init__(self, **kwargs):
self.model_.setup_caches(
batch_size=1,
dtype=self.dtype,
encoder_max_seq_len=self.encoder_max_seq_len,
decoder_max_seq_len=self.max_seq_len,
)
# number of tokens for example input
self.n_tokens = 34
self.model_.to(self.dtype)

def get_eager_model(self) -> torch.nn.Module:
if self.dtype:
return self.model_.to(self.dtype)
else:
return self.model_.to(torch.float16)
return self.model_

def get_example_inputs(self):
return (torch.ones(1, 32, dtype=torch.long),)
return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

def get_example_kwarg_inputs(self):
# For export we must use the prefill versions of the
# causal mask and input_pos.
# Hardcoding # of tiles to be 2. image tokens per tile is 1601.
if self.use_kv_cache:
return {
"input_pos": self.input_pos[None, :32],
"mask": self.causal_mask[None, :32],
# "encoder_input": None,
# "encoder_mask": None,
"input_pos": self.input_pos[None, : self.n_tokens],
"mask": self.causal_mask[None, : self.n_tokens],
"encoder_input": torch.randn(
1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
),
"encoder_mask": torch.ones(
[1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
),
}
else:
return None

def get_dynamic_shapes(self):
batch_size = 1
dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
# Hardcoding # of tiles to be 2. image tokens per tile is 1601.
if self.use_kv_cache:
dynamic_shapes = {
"tokens": {0: batch_size, 1: dim_seq_len},
# "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
# "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
"encoder_input": None,
"encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
"mask": {0: batch_size, 1: dim_seq_len, 2: None},
"input_pos": {0: batch_size, 1: dim_seq_len},
}
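
The export-enabling piece above is the source transformation replace_mha_with_inference_mha applied in __init__. As a hedged illustration only (the real helper lives in executorch.extension.llm.modules.attention and may differ in detail), a transformation of this kind typically walks the module tree and swaps attention modules in place; swap_modules and build_replacement below are hypothetical names, not the actual API.

import torch.nn as nn

def swap_modules(model: nn.Module, old_cls, build_replacement) -> nn.Module:
    # Recursively replace every instance of old_cls, handing the original
    # module (and therefore its weights) to build_replacement.
    for name, child in model.named_children():
        if isinstance(child, old_cls):
            setattr(model, name, build_replacement(child))
        else:
            swap_modules(child, old_cls, build_replacement)
    return model
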
Empty file.
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for the text decoder are covered by test_models.sh.
# Only test AOTI in this file.
import json
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,
)
from torch.testing import assert_close

params = {
    "dim": 2048,
    "ffn_dim_multiplier": 1.3,
    "fusion_interval": 2,
    "intermediate_dim": 14336,
    "multiple_of": 1024,
    "n_heads": 32,
    "n_kv_heads": 8,
    "n_layers": 2,
    "n_special_tokens": 8,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "use_scaled_rope": True,
    "vision_chunk_size": 560,
    "vision_max_num_chunks": 4,
    "vocab_size": 21008,
    "vision_num_cross_attention_layers": 1,
}


class TextDecoderTest(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def _set_requires_grad_false(self, model: torch.nn.Module) -> None:
        for param in model.parameters():
            param.requires_grad = False
        for child in model.children():
            self._set_requires_grad_false(child)

    def test_llama3_2_text_decoder_aoti(self) -> None:
        with tempfile.NamedTemporaryFile(mode="w") as param_file:
            json.dump(params, param_file, indent=2)
            param_file.flush()
            model = Llama3_2Decoder(
                encoder_max_seq_len=6404,
                generate_full_logits=True,
                enable_dynamic_shape=True,
                use_kv_cache=True,
                params=param_file.name,
                dtype=torch.float32,
            )
            encoder = model.get_eager_model().eval()
            self._set_requires_grad_false(encoder)

            # AOTI
            with torch.no_grad(), torch.inference_mode():
                ep = torch.export.export(
                    encoder,
                    model.get_example_inputs(),
                    kwargs=model.get_example_kwarg_inputs(),
                    dynamic_shapes=model.get_dynamic_shapes(),
                )
            with tempfile.TemporaryDirectory() as tmpdir:
                path = torch._inductor.aoti_compile_and_package(
                    ep,
                    model.get_example_inputs(),
                    kwargs=model.get_example_kwarg_inputs(),
                    package_path=os.path.join(tmpdir, "text_decoder.pt2"),
                )
                encoder_aoti = torch._inductor.aoti_load_package(path)

                y = encoder_aoti(
                    *model.get_example_inputs(), **model.get_example_kwarg_inputs()
                )

            eager_res = encoder.forward(
                *model.get_example_inputs(), **model.get_example_kwarg_inputs()
            )
            assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
5 changes: 3 additions & 2 deletions pytest.ini
@@ -18,6 +18,7 @@ addopts =
examples/models/llama/tests
examples/models/llama3_2_vision/preprocess
examples/models/llama3_2_vision/vision_encoder/test
examples/models/llama3_2_vision/text_decoder/test
# examples/models/llava/test TODO: enable this
# exir
exir/_serialize/test
@@ -43,8 +44,8 @@ addopts =
extension/pybindings/test
# Runtime
runtime
# test
test/end2end/test_end2end.py
# test TODO: fix these tests
# test/end2end/test_end2end.py
--ignore=backends/xnnpack/test/ops/linear.py
--ignore=backends/xnnpack/test/models/llama2_et_example.py
# T200992559: Add torchao to ET as core dependency