diff --git a/exir/dim_order_utils.py b/exir/dim_order_utils.py index 562244b6a4..a0551c6f4d 100644 --- a/exir/dim_order_utils.py +++ b/exir/dim_order_utils.py @@ -61,3 +61,21 @@ def get_dim_order( raise AssertionError( f"Failed to generate dim_order for a given memory format: {memory_format}" ) + + +def is_channel_last_dim_order(tensor: torch.Tensor) -> bool: + """ + Check if a tensor has channels last dim order + """ + if tensor.dim() != 4: + # Only support 4D tensors for channel list memory format. + return False + + return tensor.dim_order() == tuple(_get_channels_last_dim_order(tensor.dim())) + + +def is_contiguous_dim_order(tensor: torch.Tensor) -> bool: + """ + Check if a tensor has contiguous dim order + """ + return tensor.dim_order() == tuple(_get_contiguous_dim_order(tensor.dim())) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 2f251ec8bf..15e73dd413 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -10,6 +10,11 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge + +from executorch.exir.dim_order_utils import ( + is_channel_last_dim_order, + is_contiguous_dim_order, +) from torch.export import export from torch.testing import FileCheck @@ -22,15 +27,6 @@ class MemoryFormatTestSet: class TestMemoryFormatOpsPass(unittest.TestCase): - def is_channel_last(self, x: torch.Tensor): - # This is a heuristic to determine if the input tensor is in NHWC (channel last) - # due to we do not have a good way to infer the dimension order or the memory format - # of the input tensor. Please not this function is specific for contiguous tensors - # whose dim(1) is channel one only, other types of tensors may not work well - # due to different channel configuration and memory arrangement. - - return x.stride(1) == 1 - def memory_format_test_runner(self, test_set: MemoryFormatTestSet): aten_op_str = "torch.ops.aten._to_copy.default" edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" @@ -60,13 +56,13 @@ def memory_format_test_runner(self, test_set: MemoryFormatTestSet): actual = epm.exported_program().module()(*test_set.sample_input) self.assertTrue(torch.allclose(actual, expected)) self.assertEqual( - self.is_channel_last(actual), - self.is_channel_last(expected), + is_channel_last_dim_order(actual), + is_channel_last_dim_order(expected), ) if test_set.target_memory_format == torch.channels_last: - self.assertTrue(self.is_channel_last(actual)) + self.assertTrue(is_channel_last_dim_order(actual)) elif test_set.target_memory_format == torch.contiguous_format: - self.assertFalse(self.is_channel_last(actual)) + self.assertTrue(is_contiguous_dim_order(actual)) else: raise RuntimeError("Unknown memory format")