Skip to content

Commit

Permalink
[spirv] Respect entry point ordinal when serializing executables (#15905
Browse files Browse the repository at this point in the history
)

We have assigned ordinals to various entry points when linking; need to
respect the order there when serializing. We were lucky before because
hal.executable.export ops are sorted in ascending order w.r.t. ordinals
thus far, and spirv.module ops follow the same order of
hal.executable.export ops. But it's not guaranteed to be so.
  • Loading branch information
antiagainst authored Dec 14, 2023
1 parent f81f361 commit 9726ead
Showing 1 changed file with 39 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,42 @@ class VulkanSPIRVTargetBackend : public TargetBackend {
return variantOp.emitError() << "should contain some spirv.module ops";
}

DenseMap<StringRef, uint64_t> entryPointOrdinals;

SmallVector<IREE::HAL::ExecutableExportOp> exportOps =
llvm::to_vector(variantOp.getOps<IREE::HAL::ExecutableExportOp>());
for (auto exportOp : exportOps) {
uint64_t ordinal = 0;
if (std::optional<APInt> optionalOrdinal = exportOp.getOrdinal()) {
ordinal = optionalOrdinal->getZExtValue();
} else {
// For executables with only one entry point, linking doesn't kick in at
// all. So the ordinal can be missing for this case.
if (!llvm::hasSingleElement(exportOps)) {
return exportOp.emitError() << "should have ordinal attribute";
}
}
entryPointOrdinals[exportOp.getSymName()] = ordinal;
}
uint64_t ordinalCount = entryPointOrdinals.size();

FlatbufferBuilder builder;
iree_hal_spirv_ExecutableDef_start_as_root(builder);

// The list of shader modules.
SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs;

// Per entry-point data.
// Note that the following vectors should all be of the same size and
// element at index #i is for entry point with ordinal #i!
SmallVector<StringRef> entryPointNames;
SmallVector<uint32_t> subgroupSizes;
SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs;
SmallVector<uint32_t> shaderModuleIndices;
SmallVector<iree_hal_spirv_FileLineLocDef_ref_t> sourceLocationRefs;
entryPointNames.resize(ordinalCount);
subgroupSizes.resize(ordinalCount);
shaderModuleIndices.resize(ordinalCount);

bool hasAnySubgroupSizes = false;

// Iterate over all spirv.module ops and encode them into the FlatBuffer
Expand All @@ -202,6 +230,7 @@ class VulkanSPIRVTargetBackend : public TargetBackend {
<< "expected to contain exactly one entry point";
}
spirv::EntryPointOp spvEntryPoint = *spirvEntryPoints.begin();
uint64_t ordinal = entryPointOrdinals.at(spvEntryPoint.getFn());

if (!options.dumpIntermediatesPath.empty()) {
std::string assembly;
Expand All @@ -223,31 +252,34 @@ class VulkanSPIRVTargetBackend : public TargetBackend {
}
auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
spvBinary.size());
shaderModuleIndices.push_back(shaderModuleRefs.size());
shaderModuleIndices[ordinal] = shaderModuleRefs.size();
shaderModuleRefs.push_back(
iree_hal_spirv_ShaderModuleDef_create(builder, spvCodeRef));

// The IREE runtime uses ordinals instead of names. We need to attach the
// entry point name for VkShaderModuleCreateInfo.
entryPointNames.push_back(spvEntryPoint.getFn());
entryPointNames[ordinal] = spvEntryPoint.getFn();

// If there are subgroup size requests, we need to pick up too.
auto fn = spvModuleOp.lookupSymbol<spirv::FuncOp>(spvEntryPoint.getFn());
auto abi = fn->getAttrOfType<spirv::EntryPointABIAttr>(
spirv::getEntryPointABIAttrName());
if (abi && abi.getSubgroupSize()) {
subgroupSizes.push_back(*abi.getSubgroupSize());
subgroupSizes[ordinal] = *abi.getSubgroupSize();
hasAnySubgroupSizes = true;
} else {
subgroupSizes.push_back(0);
subgroupSizes[ordinal] = 0;
}

// Optional source location information for debugging/profiling.
if (options.debugLevel >= 1) {
if (auto loc = findFirstFileLoc(spvEntryPoint.getLoc())) {
// We only ever resize to the maximum -- so all previous data will be
// kept as-is.
sourceLocationRefs.resize(ordinalCount);
auto filenameRef = builder.createString(loc->getFilename());
sourceLocationRefs.push_back(iree_hal_spirv_FileLineLocDef_create(
builder, filenameRef, loc->getLine()));
sourceLocationRefs[ordinal] = iree_hal_spirv_FileLineLocDef_create(
builder, filenameRef, loc->getLine());
}
};
}
Expand Down

0 comments on commit 9726ead

Please sign in to comment.