Skip to content

Commit

Permalink
igl | vulkan | Store reflection info for shaders
Browse files Browse the repository at this point in the history
Summary: Store reflection info for all shaders. Reflection info is obtained from SPIR-V, so it works for GLSL and SparkSL shaders.

Reviewed By: EricGriffith

Differential Revision: D50368271

fbshipit-source-id: 510fdf52022860a8e0db2509c3c1e5ee8b0529f1
  • Loading branch information
corporateshark authored and facebook-github-bot committed Oct 22, 2023
1 parent ef8e69a commit 2b0e793
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 31 deletions.
17 changes: 13 additions & 4 deletions src/igl/vulkan/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,11 @@ std::shared_ptr<VulkanShaderModule> Device::createShaderModule(const void* data,

// @fb-only
// @lint-ignore CLANGTIDY
return std::make_shared<VulkanShaderModule>(ctx_->vf_, device, vkShaderModule);
return std::make_shared<VulkanShaderModule>(
ctx_->vf_,
device,
vkShaderModule,
util::getReflectionData(reinterpret_cast<const uint32_t*>(data), length));
}

std::shared_ptr<VulkanShaderModule> Device::createShaderModule(ShaderStage stage,
Expand Down Expand Up @@ -342,9 +346,13 @@ std::shared_ptr<VulkanShaderModule> Device::createShaderModule(ShaderStage stage
glslang_resource_t glslangResource;
ivkGlslangResource(&glslangResource, &ctx_->getVkPhysicalDeviceProperties());

std::vector<uint32_t> spirv;
const Result result =
igl::vulkan::compileShader(ctx_->vf_, device, vkStage, source, spirv, &glslangResource);

VkShaderModule vkShaderModule = VK_NULL_HANDLE;
const Result result = igl::vulkan::compileShader(
ctx_->vf_, device, vkStage, source, &vkShaderModule, &glslangResource);
VK_ASSERT(ivkCreateShaderModuleFromSPIRV(
&ctx_->vf_, device, spirv.data(), spirv.size() * sizeof(uint32_t), &vkShaderModule));

Result::setResult(outResult, result);

Expand All @@ -363,7 +371,8 @@ std::shared_ptr<VulkanShaderModule> Device::createShaderModule(ShaderStage stage

// @fb-only
// @lint-ignore CLANGTIDY
return std::make_shared<VulkanShaderModule>(ctx_->vf_, device, vkShaderModule);
return std::make_shared<VulkanShaderModule>(
ctx_->vf_, device, vkShaderModule, util::getReflectionData(spirv.data(), spirv.size()));
}

std::shared_ptr<IFramebuffer> Device::createFramebuffer(const FramebufferDesc& desc,
Expand Down
12 changes: 0 additions & 12 deletions src/igl/vulkan/VulkanHelpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -1260,18 +1260,6 @@ glslang_input_t ivkGetGLSLangInput(VkShaderStageFlagBits stage,
return input;
}

VkResult ivkCreateShaderModule(const struct VulkanFunctionTable* vt,
VkDevice device,
glslang_program_t* program,
VkShaderModule* outShaderModule) {
const VkShaderModuleCreateInfo ci = {
.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
.codeSize = glslang_program_SPIRV_get_size(program) * sizeof(uint32_t),
.pCode = glslang_program_SPIRV_get_ptr(program),
};
return vt->vkCreateShaderModule(device, &ci, NULL, outShaderModule);
}

VkResult ivkCreateShaderModuleFromSPIRV(const struct VulkanFunctionTable* vt,
VkDevice device,
const void* dataSPIRV,
Expand Down
5 changes: 0 additions & 5 deletions src/igl/vulkan/VulkanHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,6 @@ VkResult ivkCreateRenderPass(const struct VulkanFunctionTable* vt,
const VkRenderPassMultiviewCreateInfo* renderPassMultiview,
VkRenderPass* outRenderPass);

VkResult ivkCreateShaderModule(const struct VulkanFunctionTable* vt,
VkDevice device,
glslang_program_t* program,
VkShaderModule* outShaderModule);

VkResult ivkCreateShaderModuleFromSPIRV(const struct VulkanFunctionTable* vt,
VkDevice device,
const void* dataSPIRV,
Expand Down
15 changes: 7 additions & 8 deletions src/igl/vulkan/VulkanShaderModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,10 @@ Result compileShader(const VulkanFunctionTable& vf,
VkDevice device,
VkShaderStageFlagBits stage,
const char* code,
VkShaderModule* outShaderModule,
std::vector<uint32_t>& outSPIRV,
const glslang_resource_t* glslLangResource) {
IGL_PROFILER_FUNCTION();

if (!outShaderModule) {
return Result(Result::Code::ArgumentNull, "outShaderModule is NULL");
}

const glslang_input_t input = ivkGetGLSLangInput(stage, glslLangResource, code);

glslang_shader_t* shader = glslang_shader_create(&input);
Expand Down Expand Up @@ -125,15 +121,18 @@ Result compileShader(const VulkanFunctionTable& vf,
IGL_LOG_ERROR("%s\n", glslang_program_SPIRV_get_messages(program));
}

VK_ASSERT_RETURN(ivkCreateShaderModule(&vf, device, program, outShaderModule));
const unsigned int* codePtr = glslang_program_SPIRV_get_ptr(program);

outSPIRV = std::vector(codePtr, codePtr + glslang_program_SPIRV_get_size(program));

return Result();
}

VulkanShaderModule::VulkanShaderModule(const VulkanFunctionTable& vf,
VkDevice device,
VkShaderModule shaderModule) :
vf_(vf), device_(device), vkShaderModule_(shaderModule) {}
VkShaderModule shaderModule,
util::SpvModuleInfo&& moduleInfo) :
vf_(vf), device_(device), vkShaderModule_(shaderModule), moduleInfo_(std::move(moduleInfo)) {}

VulkanShaderModule::~VulkanShaderModule() {
vf_.vkDestroyShaderModule(device_, vkShaderModule_, nullptr);
Expand Down
10 changes: 8 additions & 2 deletions src/igl/vulkan/VulkanShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
#pragma once

#include <memory>
#include <vector>

#include <igl/vulkan/Common.h>
#include <igl/vulkan/VulkanFunctions.h>
#include <igl/vulkan/VulkanHelpers.h>
#include <igl/vulkan/util/SpvReflection.h>

namespace igl {
namespace vulkan {
Expand All @@ -20,7 +22,7 @@ Result compileShader(const VulkanFunctionTable& vf,
VkDevice device,
VkShaderStageFlagBits stage,
const char* code,
VkShaderModule* outShaderModule,
std::vector<uint32_t>& outSPIRV,
const glslang_resource_t* glslLangResource = nullptr);

/**
Expand All @@ -29,7 +31,10 @@ Result compileShader(const VulkanFunctionTable& vf,
class VulkanShaderModule final {
public:
/** @brief Instantiates a shader module wrapper with the module and the device that owns it */
VulkanShaderModule(const VulkanFunctionTable& vf, VkDevice device, VkShaderModule shaderModule);
VulkanShaderModule(const VulkanFunctionTable& vf,
VkDevice device,
VkShaderModule shaderModule,
util::SpvModuleInfo&& moduleInfo);
~VulkanShaderModule();

/** @brief Returns the underlying Vulkan shader module */
Expand All @@ -41,6 +46,7 @@ class VulkanShaderModule final {
const VulkanFunctionTable& vf_;
VkDevice device_ = VK_NULL_HANDLE;
VkShaderModule vkShaderModule_ = VK_NULL_HANDLE;
util::SpvModuleInfo moduleInfo_ = {};
};

} // namespace vulkan
Expand Down

0 comments on commit 2b0e793

Please sign in to comment.