diff --git a/src/igl/vulkan/Device.cpp b/src/igl/vulkan/Device.cpp index cc60ad3642..5cd600eb20 100644 --- a/src/igl/vulkan/Device.cpp +++ b/src/igl/vulkan/Device.cpp @@ -267,7 +267,11 @@ std::shared_ptr Device::createShaderModule(const void* data, // @fb-only // @lint-ignore CLANGTIDY - return std::make_shared(ctx_->vf_, device, vkShaderModule); + return std::make_shared( + ctx_->vf_, + device, + vkShaderModule, + util::getReflectionData(reinterpret_cast(data), length)); } std::shared_ptr Device::createShaderModule(ShaderStage stage, @@ -342,9 +346,13 @@ std::shared_ptr Device::createShaderModule(ShaderStage stage glslang_resource_t glslangResource; ivkGlslangResource(&glslangResource, &ctx_->getVkPhysicalDeviceProperties()); + std::vector spirv; + const Result result = + igl::vulkan::compileShader(ctx_->vf_, device, vkStage, source, spirv, &glslangResource); + VkShaderModule vkShaderModule = VK_NULL_HANDLE; - const Result result = igl::vulkan::compileShader( - ctx_->vf_, device, vkStage, source, &vkShaderModule, &glslangResource); + VK_ASSERT(ivkCreateShaderModuleFromSPIRV( + &ctx_->vf_, device, spirv.data(), spirv.size() * sizeof(uint32_t), &vkShaderModule)); Result::setResult(outResult, result); @@ -363,7 +371,8 @@ std::shared_ptr Device::createShaderModule(ShaderStage stage // @fb-only // @lint-ignore CLANGTIDY - return std::make_shared(ctx_->vf_, device, vkShaderModule); + return std::make_shared( + ctx_->vf_, device, vkShaderModule, util::getReflectionData(spirv.data(), spirv.size())); } std::shared_ptr Device::createFramebuffer(const FramebufferDesc& desc, diff --git a/src/igl/vulkan/VulkanHelpers.c b/src/igl/vulkan/VulkanHelpers.c index 520e0a9984..f467ee6d55 100644 --- a/src/igl/vulkan/VulkanHelpers.c +++ b/src/igl/vulkan/VulkanHelpers.c @@ -1260,18 +1260,6 @@ glslang_input_t ivkGetGLSLangInput(VkShaderStageFlagBits stage, return input; } -VkResult ivkCreateShaderModule(const struct VulkanFunctionTable* vt, - VkDevice device, - glslang_program_t* program, - VkShaderModule* outShaderModule) { - const VkShaderModuleCreateInfo ci = { - .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, - .codeSize = glslang_program_SPIRV_get_size(program) * sizeof(uint32_t), - .pCode = glslang_program_SPIRV_get_ptr(program), - }; - return vt->vkCreateShaderModule(device, &ci, NULL, outShaderModule); -} - VkResult ivkCreateShaderModuleFromSPIRV(const struct VulkanFunctionTable* vt, VkDevice device, const void* dataSPIRV, diff --git a/src/igl/vulkan/VulkanHelpers.h b/src/igl/vulkan/VulkanHelpers.h index a3f5356e7e..e6b0bbc6af 100644 --- a/src/igl/vulkan/VulkanHelpers.h +++ b/src/igl/vulkan/VulkanHelpers.h @@ -159,11 +159,6 @@ VkResult ivkCreateRenderPass(const struct VulkanFunctionTable* vt, const VkRenderPassMultiviewCreateInfo* renderPassMultiview, VkRenderPass* outRenderPass); -VkResult ivkCreateShaderModule(const struct VulkanFunctionTable* vt, - VkDevice device, - glslang_program_t* program, - VkShaderModule* outShaderModule); - VkResult ivkCreateShaderModuleFromSPIRV(const struct VulkanFunctionTable* vt, VkDevice device, const void* dataSPIRV, diff --git a/src/igl/vulkan/VulkanShaderModule.cpp b/src/igl/vulkan/VulkanShaderModule.cpp index 32208d0acf..b0860fe159 100644 --- a/src/igl/vulkan/VulkanShaderModule.cpp +++ b/src/igl/vulkan/VulkanShaderModule.cpp @@ -60,14 +60,10 @@ Result compileShader(const VulkanFunctionTable& vf, VkDevice device, VkShaderStageFlagBits stage, const char* code, - VkShaderModule* outShaderModule, + std::vector& outSPIRV, const glslang_resource_t* glslLangResource) { IGL_PROFILER_FUNCTION(); - if (!outShaderModule) { - return Result(Result::Code::ArgumentNull, "outShaderModule is NULL"); - } - const glslang_input_t input = ivkGetGLSLangInput(stage, glslLangResource, code); glslang_shader_t* shader = glslang_shader_create(&input); @@ -125,15 +121,18 @@ Result compileShader(const VulkanFunctionTable& vf, IGL_LOG_ERROR("%s\n", glslang_program_SPIRV_get_messages(program)); } - VK_ASSERT_RETURN(ivkCreateShaderModule(&vf, device, program, outShaderModule)); + const unsigned int* codePtr = glslang_program_SPIRV_get_ptr(program); + + outSPIRV = std::vector(codePtr, codePtr + glslang_program_SPIRV_get_size(program)); return Result(); } VulkanShaderModule::VulkanShaderModule(const VulkanFunctionTable& vf, VkDevice device, - VkShaderModule shaderModule) : - vf_(vf), device_(device), vkShaderModule_(shaderModule) {} + VkShaderModule shaderModule, + util::SpvModuleInfo&& moduleInfo) : + vf_(vf), device_(device), vkShaderModule_(shaderModule), moduleInfo_(std::move(moduleInfo)) {} VulkanShaderModule::~VulkanShaderModule() { vf_.vkDestroyShaderModule(device_, vkShaderModule_, nullptr); diff --git a/src/igl/vulkan/VulkanShaderModule.h b/src/igl/vulkan/VulkanShaderModule.h index 85588881bf..ac2587d6b6 100644 --- a/src/igl/vulkan/VulkanShaderModule.h +++ b/src/igl/vulkan/VulkanShaderModule.h @@ -8,10 +8,12 @@ #pragma once #include +#include #include #include #include +#include namespace igl { namespace vulkan { @@ -20,7 +22,7 @@ Result compileShader(const VulkanFunctionTable& vf, VkDevice device, VkShaderStageFlagBits stage, const char* code, - VkShaderModule* outShaderModule, + std::vector& outSPIRV, const glslang_resource_t* glslLangResource = nullptr); /** @@ -29,7 +31,10 @@ Result compileShader(const VulkanFunctionTable& vf, class VulkanShaderModule final { public: /** @brief Instantiates a shader module wrapper with the module and the device that owns it */ - VulkanShaderModule(const VulkanFunctionTable& vf, VkDevice device, VkShaderModule shaderModule); + VulkanShaderModule(const VulkanFunctionTable& vf, + VkDevice device, + VkShaderModule shaderModule, + util::SpvModuleInfo&& moduleInfo); ~VulkanShaderModule(); /** @brief Returns the underlying Vulkan shader module */ @@ -41,6 +46,7 @@ class VulkanShaderModule final { const VulkanFunctionTable& vf_; VkDevice device_ = VK_NULL_HANDLE; VkShaderModule vkShaderModule_ = VK_NULL_HANDLE; + util::SpvModuleInfo moduleInfo_ = {}; }; } // namespace vulkan