Fix flaky ET attention test (pytorch#6795)
* Fix flaky ET attention test

* Use assert_close

* Remove msg from assert_close

---------

Co-authored-by: Mengwei Liu <[email protected]>
dvorjackz and larryliu0820 authored Nov 12, 2024
1 parent b6ebd3c commit f943856
Showing 2 changed files with 7 additions and 5 deletions.
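Why `assert_close` makes the test less flaky: `self.assertTrue(torch.allclose(...))` compares with fixed defaults (rtol=1e-5, atol=1e-8) and collapses the result to a bare boolean, so benign floating-point noise can fail the test with no diagnostics beyond "False is not true". `torch.testing.assert_close` picks dtype-aware tolerances (rtol=1.3e-6, atol=1e-5 for float32) and, on failure, reports the greatest absolute and relative differences. A minimal illustration (not part of this commit):

import torch
from torch.testing import assert_close

torch.manual_seed(0)
a = torch.randn(4, 8)
b = a + 5e-6  # tiny float32 noise, well within dtype-based test tolerances

print(torch.allclose(a, b))  # False: near-zero elements miss atol=1e-8
assert_close(a, b)           # passes: float32 defaults (rtol=1.3e-6, atol=1e-5) absorb it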
11 changes: 7 additions & 4 deletions extension/llm/modules/test/test_attention.py
@@ -13,6 +13,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention

@@ -94,7 +95,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x) # Self attention.
         tt_res = self.tt_mha(self.x, self.x) # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # test with kv cache
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -124,7 +125,8 @@ def test_attention_eager(self):
         tt_res = self.tt_mha(
             self.x, self.x, input_pos=next_input_pos
         ) # Self attention with input pos.
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)

     def test_attention_export(self):
         # Self attention.
@@ -136,7 +138,8 @@ def test_attention_export(self):
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)

         # TODO: KV cache.

@@ -162,6 +165,6 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x, self.input_pos))
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        assert_close(et_res[0], tt_res)

         # TODO: KV cache.
1 change: 0 additions & 1 deletion pytest.ini
@@ -39,7 +39,6 @@ addopts =
     backends/xnnpack/test
     # extension/
     extension/llm/modules/test
-    --ignore=extension/llm/modules/test/test_mha.py
     extension/pybindings/test
     # Runtime
     runtime
