diff --git a/backends/xnnpack/test/ops/cat.py b/backends/xnnpack/test/ops/cat.py index 1eb38b6828..cb6496e4ab 100644 --- a/backends/xnnpack/test/ops/cat.py +++ b/backends/xnnpack/test/ops/cat.py @@ -21,17 +21,18 @@ def forward(self, x, y): def test_fp32_cat(self): inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3)) ( - Tester(self.Cat(), inputs).export() - # .check_count({"torch.ops.aten.cat": 1}) - # .to_edge() - # .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1}) - # .partition() - # .check_count({"torch.ops.executorch_call_delegate": 1}) - # .check_not(["executorch_exir_dialects_edge__ops_aten_cat"]) - # .to_executorch() - # .serialize() - # .run_method() - # .compare_outputs() + Tester(self.Cat(), inputs) + .export() + .check_count({"torch.ops.aten.cat": 1}) + .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1}) + .partition() + .check_count({"torch.ops.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_cat"]) + .to_executorch() + .serialize() + .run_method() + .compare_outputs() ) class CatNegativeDim(torch.nn.Module):