From 1d7d71dbf225f1827d88e2990fe85fed2c57767e Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 18 Jul 2024 10:20:15 -0700 Subject: [PATCH] Allow expression of scalar tensor buffers, non string values in variants (#4292) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4292 Some simple improvements to the SPIR-V compilation script: 1. Allow `layout_declare_tensor` to create a scalar buffer instead of always creating a vectorized buffer 2. Allow handling of non-string (i.e. int) values in shader codegen YAML configurations. Reviewed By: jorgep31415 Differential Revision: D59877805 fbshipit-source-id: 579888fbc19d19a0d24f2fbd831e74f4ba32f033 --- backends/vulkan/runtime/gen_vulkan_spv.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index aa32b9ab70..c9e3aaa31e 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -231,6 +231,7 @@ def layout_declare_tensor( var_name: str, dtype: str, storage_type: str, + is_scalar_array: bool = False, precision: str = "PRECISION", ) -> str: assert storage_type.lower() in ["buffer", "texture3d", "texture2d"] @@ -242,7 +243,12 @@ def layout_declare_tensor( # Create buffer binding if storage_type.lower() == "buffer": return layout_declare_buffer( - slot, access_type, var_name, dtype, precision, is_scalar_array=False + slot, + access_type, + var_name, + dtype, + precision, + is_scalar_array=is_scalar_array, ) # Create image/sampler binding @@ -533,7 +539,7 @@ def generateVariantCombinations( curr_suffix = ( suffix + "_" + str(i) if suffix else str(i) ) - param_values.append((param_name, curr_suffix, str(i))) + param_values.append((param_name, curr_suffix, i)) else: raise ValueError( f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)." @@ -595,7 +601,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: variant_name = variant["NAME"] for param_value in combination: default_params_copy[param_value[0]] = param_value[2] - if len(param_value[1]) > 0: + if len(str(param_value[1])) > 0: variant_name = f"{variant_name}_{param_value[1]}" default_params_copy["NAME"] = variant_name