Skip to content

Commit

Permalink
[core] Add support for aarch64 calling conventions. (#2278)
Browse files Browse the repository at this point in the history
Fix calling conventions for aarch64 and add a regression test.

Requires #2268 to be merged.

Signed-off-by: Eric Schweitz <[email protected]>
Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
schweitzpgi authored and annagrin committed Oct 17, 2024
1 parent 01e0a19 commit 789061c
Show file tree
Hide file tree
Showing 2 changed files with 370 additions and 12 deletions.
56 changes: 44 additions & 12 deletions lib/Optimizer/Builder/Factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,15 @@ static Type convertToHostSideType(Type ty) {
// integers of size 8, 16, 24, 32 or 64 together, unless the float member fits
// by itself.
static bool shouldExpand(SmallVectorImpl<Type> &packedTys,
cc::StructType structTy) {
cc::StructType structTy, unsigned scaling = 8) {
if (structTy.isEmpty())
return false;
auto *ctx = structTy.getContext();
unsigned bits = 0;
const auto scaleBy = scaling - 1;
auto scaleBits = [&](unsigned size) {
if (size < 32)
size = (size + 7) & ~7u;
size = (size + scaleBy) & ~scaleBy;
if (size > 32 && size <= 64)
size = 64;
return size;
Expand Down Expand Up @@ -525,6 +526,17 @@ static bool onlyArithmeticMembers(cc::StructType structTy) {
return true;
}

// Scan the members of `structTy` and return the bit width of the widest one.
// The result is never smaller than 8, even if every member is narrower.
// Unchecked precondition: structTy must be entirely arithmetic, as the width
// of each member is queried with getIntOrFloatBitWidth().
static unsigned getLargestWidth(cc::StructType structTy) {
  // Start from a floor of 8 bits; only strictly wider members replace it.
  unsigned maxWidth = 8;
  for (auto memberTy : structTy.getMembers()) {
    const auto bits = memberTy.getIntOrFloatBitWidth();
    maxWidth = (bits > maxWidth) ? bits : maxWidth;
  }
  return maxWidth;
}

// When the kernel comes from a class, there is always a default `this` argument
// to the kernel entry function. The CUDA-Q spec doesn't allow the kernel
// object to contain data members (yet), so we can ignore the `this` pointer.
Expand All @@ -534,16 +546,31 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
SmallVector<Type> inputTys;
bool hasSRet = false;
Type resultTy;
auto i64Ty = IntegerType::get(ctx, 64);
if (funcTy.getNumResults() == 1)
if (auto strTy = dyn_cast<cc::StructType>(funcTy.getResult(0)))
if (strTy.getBitSize() != 0 &&
strTy.getBitSize() <= CommonSmallStructSize) {
SmallVector<Type, 2> packedTys;
if (shouldExpand(packedTys, strTy) || !packedTys.empty()) {
if (packedTys.size() == 1)
resultTy = packedTys[0];
else
resultTy = cc::StructType::get(ctx, packedTys);
if (isX86_64(module)) {
// X86_64: Byte addressable scaling (packed registers). Default is a
// struct.
SmallVector<Type, 2> packedTys;
if (shouldExpand(packedTys, strTy) || !packedTys.empty()) {
if (packedTys.size() == 1)
resultTy = packedTys[0];
else
resultTy = cc::StructType::get(ctx, packedTys);
}
} else if (isAArch64(module) && onlyArithmeticMembers(strTy)) {
// AARCH64: Padded registers. Default is a two-element array.
unsigned largest = getLargestWidth(strTy);
SmallVector<Type, 2> packedTys;
if (shouldExpand(packedTys, strTy, largest) || !packedTys.empty()) {
if (packedTys.size() == 1)
resultTy = packedTys[0];
else
resultTy = cc::ArrayType::get(ctx, packedTys[0], 2);
}
}
}
if (!resultTy && funcTy.getNumResults()) {
Expand All @@ -562,7 +589,6 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
}
// If this kernel is a plain old function or a static member function, we
// don't want to add a hidden `this` argument.
auto i64Ty = IntegerType::get(ctx, 64);
auto ptrTy = cc::PointerType::get(IntegerType::get(ctx, 8));
if (addThisPtr)
inputTys.push_back(ptrTy);
Expand Down Expand Up @@ -592,9 +618,15 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
if (onlyArithmeticMembers(strTy)) {
// Empirical evidence shows that on aarch64, arguments are packed
// into a single i64 or a [2 x i64] typed value based on the size
// of the struct. This is regardless of whether the value(s) are
// floating-point or not.
if (strTy.getBitSize() > 64)
// of the struct. The exception is when there are 2 elements and
// they are both float or both double.
if ((strTy.getMembers().size() == 2) &&
(strTy.getMember(0) == strTy.getMember(1)) &&
((strTy.getMember(0) == Float32Type::get(ctx)) ||
(strTy.getMember(0) == Float64Type::get(ctx))))
inputTys.push_back(
cc::ArrayType::get(ctx, strTy.getMember(0), 2));
else if (strTy.getBitSize() > 64)
inputTys.push_back(cc::ArrayType::get(ctx, i64Ty, 2));
else
inputTys.push_back(i64Ty);
Expand Down
Loading

0 comments on commit 789061c

Please sign in to comment.