diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 9ae136a213..565e8c67d7 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -13,6 +13,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
 
@@ -94,7 +95,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
 
         # test with kv cache
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -124,7 +125,8 @@ def test_attention_eager(self):
         tt_res = self.tt_mha(
             self.x, self.x, input_pos=next_input_pos
         )  # Self attention with input pos.
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)
 
     def test_attention_export(self):
         # Self attention.
@@ -136,7 +138,8 @@ def test_attention_export(self):
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)
 
         # TODO: KV cache.
 
@@ -162,6 +165,6 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x, self.input_pos))
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        assert_close(et_res[0], tt_res)
 
         # TODO: KV cache.
diff --git a/pytest.ini b/pytest.ini
index 03c015c397..a5041504ae 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -39,7 +39,6 @@ addopts =
     backends/xnnpack/test
     # extension/
    extension/llm/modules/test
-    --ignore=extension/llm/modules/test/test_mha.py
     extension/pybindings/test
     # Runtime
     runtime
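
For context, a minimal standalone sketch of the behavioral difference this patch relies on: `torch.testing.assert_close` raises an `AssertionError` that describes the mismatch and uses dtype-dependent default tolerances, whereas a failed `self.assertTrue(torch.allclose(...))` only reports that the boolean was False. The tensors `a` and `b` below are made up for illustration and are not part of the patch.

# Illustrative comparison of the two assertion styles.
import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

# Old style: on failure, unittest only reports "False is not true".
# self.assertTrue(torch.allclose(a, b))

# New style: on failure, assert_close raises an AssertionError describing
# the number of mismatched elements and the greatest absolute/relative
# difference, with tolerances chosen per dtype.
try:
    assert_close(a, b)
except AssertionError as e:
    print(e)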