Skip to content

Commit

Permalink
Inline test custom_pass (pytorch#3895)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3895

`MeanToSumDiv()` and the upcoming `I64toI32()`  should be compatible with all ET-VK models. Hence, we apply them to all Python tests.
ghstack-source-id: 229355725
exported-using-ghexport

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: SS-JIA

Differential Revision: D58272547

fbshipit-source-id: fe9c923366281b6eebfec661b4e73c8fb4693292
  • Loading branch information
jorgep31415 authored and facebook-github-bot committed Jun 7, 2024
1 parent aadfc0f commit 91c0485
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import ctypes
import unittest
from typing import List, Optional, Tuple
from typing import Tuple

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

Expand All @@ -18,7 +18,6 @@
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from executorch.exir.pass_base import ExportPass
from torch.export import Dim, export, ExportedProgram

ctypes.CDLL("libvulkan.so.1")
Expand Down Expand Up @@ -98,7 +97,6 @@ def lower_module_and_test_output(
test_inputs=None,
memory_layouts=None,
first_output_only=False,
custom_pass: Optional[List[ExportPass]] = None,
):
"""
Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
Expand All @@ -120,8 +118,7 @@ def run_test(memory_layout):
)
edge_program: EdgeProgramManager = to_edge(program)

if custom_pass is not None:
edge_program = edge_program.transform(custom_pass)
edge_program = edge_program.transform([MeanToSumDiv()])

edge_program = edge_program.to_backend(VulkanPartitioner(compile_options))

Expand Down Expand Up @@ -1344,35 +1341,30 @@ def forward(self, x):
MeanModule(dims=[-1, -2]),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
custom_pass=[MeanToSumDiv()],
)

self.lower_module_and_test_output(
MeanModule(dims=[1]),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
custom_pass=[MeanToSumDiv()],
)

self.lower_module_and_test_output(
MeanModule(dims=[0, 1, 2, 3]),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
custom_pass=[MeanToSumDiv()],
)

self.lower_module_and_test_output(
MeanModule(dims=[-1, -2], keepdim=False),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
custom_pass=[MeanToSumDiv()],
)

self.lower_module_and_test_output(
MeanModule(dims=[1], keepdim=False),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
custom_pass=[MeanToSumDiv()],
)

def test_vulkan_backend_index_select_int(self):
Expand Down

0 comments on commit 91c0485

Please sign in to comment.