diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 72e4fc22b7..d16761a691 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -6,7 +6,7 @@ import ctypes import unittest -from typing import List, Optional, Tuple +from typing import Tuple import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema @@ -18,7 +18,6 @@ from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from executorch.exir.pass_base import ExportPass from torch.export import Dim, export, ExportedProgram ctypes.CDLL("libvulkan.so.1") @@ -98,7 +97,6 @@ def lower_module_and_test_output( test_inputs=None, memory_layouts=None, first_output_only=False, - custom_pass: Optional[List[ExportPass]] = None, ): """ Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with @@ -120,8 +118,7 @@ def run_test(memory_layout): ) edge_program: EdgeProgramManager = to_edge(program) - if custom_pass is not None: - edge_program = edge_program.transform(custom_pass) + edge_program = edge_program.transform([MeanToSumDiv()]) edge_program = edge_program.to_backend(VulkanPartitioner(compile_options)) @@ -1344,35 +1341,30 @@ def forward(self, x): MeanModule(dims=[-1, -2]), sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], - custom_pass=[MeanToSumDiv()], ) self.lower_module_and_test_output( MeanModule(dims=[1]), sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], - custom_pass=[MeanToSumDiv()], ) self.lower_module_and_test_output( MeanModule(dims=[0, 1, 2, 3]), sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], - custom_pass=[MeanToSumDiv()], ) self.lower_module_and_test_output( MeanModule(dims=[-1, -2], keepdim=False), sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], - custom_pass=[MeanToSumDiv()], ) self.lower_module_and_test_output( MeanModule(dims=[1], keepdim=False), sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], - custom_pass=[MeanToSumDiv()], ) def test_vulkan_backend_index_select_int(self):