diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index c8230208dbbd..9fab372bc34e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -182,14 +182,42 @@ class VulkanSPIRVTargetBackend : public TargetBackend { return variantOp.emitError() << "should contain some spirv.module ops"; } + DenseMap<StringRef, uint64_t> entryPointOrdinals; + + SmallVector<IREE::HAL::ExecutableExportOp> exportOps = + llvm::to_vector(variantOp.getOps<IREE::HAL::ExecutableExportOp>()); + for (auto exportOp : exportOps) { + uint64_t ordinal = 0; + if (std::optional<APInt> optionalOrdinal = exportOp.getOrdinal()) { + ordinal = optionalOrdinal->getZExtValue(); + } else { + // For executables with only one entry point, linking doesn't kick in at + // all. So the ordinal can be missing for this case. + if (!llvm::hasSingleElement(exportOps)) { + return exportOp.emitError() << "should have ordinal attribute"; + } + } + entryPointOrdinals[exportOp.getSymName()] = ordinal; + } + uint64_t ordinalCount = entryPointOrdinals.size(); + FlatbufferBuilder builder; iree_hal_spirv_ExecutableDef_start_as_root(builder); + // The list of shader modules. + SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs; + + // Per entry-point data. + // Note that the following vectors should all be of the same size and + // element at index #i is for entry point with ordinal #i! SmallVector<StringRef> entryPointNames; SmallVector<uint32_t> subgroupSizes; - SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs; SmallVector<uint32_t> shaderModuleIndices; SmallVector<iree_hal_spirv_FileLineLocDef_ref_t> sourceLocationRefs; + entryPointNames.resize(ordinalCount); + subgroupSizes.resize(ordinalCount); + shaderModuleIndices.resize(ordinalCount); + bool hasAnySubgroupSizes = false; // Iterate over all spirv.module ops and encode them into the FlatBuffer @@ -202,6 +230,7 @@ class VulkanSPIRVTargetBackend : public TargetBackend { << "expected to contain exactly one entry point"; } spirv::EntryPointOp spvEntryPoint = *spirvEntryPoints.begin(); + uint64_t ordinal = entryPointOrdinals.at(spvEntryPoint.getFn()); if (!options.dumpIntermediatesPath.empty()) { std::string assembly; @@ -223,31 +252,34 @@ class VulkanSPIRVTargetBackend : public TargetBackend { } auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(), spvBinary.size()); - shaderModuleIndices.push_back(shaderModuleRefs.size()); + shaderModuleIndices[ordinal] = shaderModuleRefs.size(); shaderModuleRefs.push_back( iree_hal_spirv_ShaderModuleDef_create(builder, spvCodeRef)); // The IREE runtime uses ordinals instead of names. We need to attach the // entry point name for VkShaderModuleCreateInfo. - entryPointNames.push_back(spvEntryPoint.getFn()); + entryPointNames[ordinal] = spvEntryPoint.getFn(); // If there are subgroup size requests, we need to pick up too. auto fn = spvModuleOp.lookupSymbol<spirv::FuncOp>(spvEntryPoint.getFn()); auto abi = fn->getAttrOfType<spirv::EntryPointABIAttr>( spirv::getEntryPointABIAttrName()); if (abi && abi.getSubgroupSize()) { - subgroupSizes.push_back(*abi.getSubgroupSize()); + subgroupSizes[ordinal] = *abi.getSubgroupSize(); hasAnySubgroupSizes = true; } else { - subgroupSizes.push_back(0); + subgroupSizes[ordinal] = 0; } // Optional source location information for debugging/profiling. if (options.debugLevel >= 1) { if (auto loc = findFirstFileLoc(spvEntryPoint.getLoc())) { + // We only ever resize to the maximum -- so all previous data will be + // kept as-is. + sourceLocationRefs.resize(ordinalCount); auto filenameRef = builder.createString(loc->getFilename()); - sourceLocationRefs.push_back(iree_hal_spirv_FileLineLocDef_create( - builder, filenameRef, loc->getLine())); + sourceLocationRefs[ordinal] = iree_hal_spirv_FileLineLocDef_create( + builder, filenameRef, loc->getLine()); } }; }