Skip to content

Commit

Permalink
[AIE2][AIE2P] Legalize G_SELECT for 512 bits only.
Browse files Browse the repository at this point in the history
  • Loading branch information
SagarMaheshwari99 committed Jan 31, 2025
1 parent 9b3ffdf commit b1b39ab
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 392 deletions.
34 changes: 30 additions & 4 deletions llvm/lib/Target/AIE/AIE2LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc. or its affiliates
// (c) Copyright 2023-2025 Advanced Micro Devices, Inc. or its affiliates
//
//===----------------------------------------------------------------------===//
/// \file
Expand Down Expand Up @@ -73,6 +73,13 @@ static LegalityPredicate isValidVectorAIE2(const unsigned TypeIdx) {
};
}

static LegalityPredicate vectorSmallerThan(unsigned TypeIdx, unsigned Size) {
return [=](const LegalityQuery &Query) {
const LLT QueryTy = Query.Types[TypeIdx];
return QueryTy.isVector() && QueryTy.getSizeInBits() < Size;
};
}

LegalityPredicate
negatePredicate(const std::function<bool(const LegalityQuery &)> &Func) {
return [=](const LegalityQuery &Query) { return !Func(Query); };
Expand Down Expand Up @@ -236,10 +243,24 @@ AIE2LegalizerInfo::AIE2LegalizerInfo(const AIE2Subtarget &ST) : AIEHelper(ST) {

getActionDefinitionsBuilder(G_SELECT)
.legalFor({{S32, S32}, {P0, S32}})
.clampScalar(1, S32, S32)
// AIE2 ISA supports only 512-bit vector select
.legalFor({V16S32, V32S16, V64S8})
// For scalar types >= 256, bitcast to a vector type and use existing
// selection patterns
.bitcastIf(
[=](const LegalityQuery &Query) {
const LLT &ResTy = Query.Types[0];
return ResTy.isScalar() && ResTy.getSizeInBits() >= 256;
},
[=](const LegalityQuery &Query) {
const LLT Ty = Query.Types[0];
const unsigned Size = Ty.getSizeInBits();
return std::pair(0, LLT::fixed_vector(Size / 32, LLT::scalar(32)));
})
.widenScalarToNextPow2(0)
.clampScalar(0, S32, S32)
.clampScalar(1, S32, S32)
.legalFor(AIE2VectorTypes)
.customIf(vectorSmallerThan(0, 512))
// We support G_SELECT only on the vector register bank
// Mapping the G_SELECT operands to the vector register bank
// during register bank selection introduces the proper cross-bank
Expand All @@ -248,7 +269,10 @@ AIE2LegalizerInfo::AIE2LegalizerInfo(const AIE2Subtarget &ST) : AIEHelper(ST) {
// type patterns in C++. Introducing bitcasts during legalization allows
// to re-use the existing code for register bank selection and ISEL
// patterns.
.bitcastIf(typeInSet(0, AIE2AccumulatorTypes), bitcastAccToVectorType(0));
.bitcastIf(typeInSet(0, AIE2AccumulatorTypes), bitcastAccToVectorType(0))
.clampMaxNumElements(0, S8, 64)
.clampMaxNumElements(0, S16, 32)
.clampMaxNumElements(0, S32, 16);

getActionDefinitionsBuilder({G_ADD, G_SUB})
.legalFor({S32})
Expand Down Expand Up @@ -546,6 +570,8 @@ bool AIE2LegalizerInfo::legalizeCustom(
return AIEHelper.legalizeG_SEXT_INREG(Helper, MI);
case TargetOpcode::G_BITCAST:
return AIEHelper.legalizeG_BITCAST(Helper, MI);
case TargetOpcode::G_SELECT:
return AIEHelper.legalizeG_SELECT(Helper, MI, /* MaxBitSize */ 512);
}

llvm_unreachable("Un-expected custom legalization");
Expand Down
59 changes: 39 additions & 20 deletions llvm/lib/Target/AIE/AIELegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc. or its affiliates
// (c) Copyright 2023-2025 Advanced Micro Devices, Inc. or its affiliates
//
//===----------------------------------------------------------------------===//
/// \file
Expand Down Expand Up @@ -1186,38 +1186,57 @@ bool AIELegalizerHelper::legalizeLoopDecrement(LegalizerHelper &Helper,
return true;
}

// Legalize 2048-bit G_SELECT
// Legalize < MaxBitSize-bit G_SELECT
// Expand the source vectors to MaxBitSize-bits by padding it with undefs.
bool AIELegalizerHelper::legalizeG_SELECT(LegalizerHelper &Helper,
MachineInstr &MI) const {
MachineInstr &MI,
const unsigned MaxBitSize) const {
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

const Register DstReg = MI.getOperand(0).getReg();
const LLT DstTy = MRI.getType(DstReg);
const unsigned DstVecSize = DstTy.getSizeInBits();

assert(DstTy.isVector() && DstVecSize < MaxBitSize &&
"Expected to legalize < MaxBitSize-bit vector G_SELECT");
assert(!(MaxBitSize % DstVecSize) &&
"Vector size should be a factor of MaxBitSize");

const Register SrcReg0 = MI.getOperand(1).getReg(); // Scalar
const Register SrcReg1 = MI.getOperand(2).getReg();
const Register SrcReg2 = MI.getOperand(3).getReg();
const LLT DstTy = MRI.getType(DstReg);
assert(DstTy.isVector() && DstTy.getSizeInBits() == 2048 &&
"Expected to legalize 2048-bit vector G_SELECT");
const LLT DstVecEltTy = DstTy.getElementType();
const unsigned ElTySize = DstVecEltTy.getSizeInBits();
const LLT ACC1024 = LLT::fixed_vector(1024 / ElTySize, ElTySize);

const Register Dst0LoReg = MRI.createGenericVirtualRegister(ACC1024);
const Register Dst0HiReg = MRI.createGenericVirtualRegister(ACC1024);
const Register Dst1LoReg = MRI.createGenericVirtualRegister(ACC1024);
const Register Dst1HiReg = MRI.createGenericVirtualRegister(ACC1024);
const LLT NewVecTy =
LLT::fixed_vector(MaxBitSize / DstTy.getElementType().getSizeInBits(),
DstTy.getElementType());

MIRBuilder.buildUnmerge({Dst0LoReg, Dst0HiReg}, SrcReg1);
MIRBuilder.buildUnmerge({Dst1LoReg, Dst1HiReg}, SrcReg2);
const Register UndefReg = MRI.createGenericVirtualRegister(DstTy);
MIRBuilder.buildUndef(UndefReg);

const unsigned NumPadElts = (MaxBitSize / DstVecSize) - 1;
auto buildMergeInstr = [&](const Register SrcReg) -> Register {
SmallVector<Register, 4> Regs;
Regs.push_back(SrcReg);
for (unsigned I = 0; I < NumPadElts; I++)
Regs.push_back(UndefReg);
const Register NewSrcReg = MRI.createGenericVirtualRegister(NewVecTy);
MIRBuilder.buildMergeLikeInstr(NewSrcReg, Regs);
return NewSrcReg;
};

const Register DstRegLoSelect = MRI.createGenericVirtualRegister(ACC1024);
const Register DstRegHiSelect = MRI.createGenericVirtualRegister(ACC1024);
const Register NewSrcReg1 = buildMergeInstr(SrcReg1);
const Register NewSrcReg2 = buildMergeInstr(SrcReg2);

MIRBuilder.buildSelect(DstRegLoSelect, SrcReg0, Dst0LoReg, Dst1LoReg);
MIRBuilder.buildSelect(DstRegHiSelect, SrcReg0, Dst0HiReg, Dst1HiReg);
const Register NewDstReg = MRI.createGenericVirtualRegister(NewVecTy);
MIRBuilder.buildInstr(MI.getOpcode(), {NewDstReg},
{SrcReg0, NewSrcReg1, NewSrcReg2}, MI.getFlags());

MIRBuilder.buildConcatVectors({DstReg}, {DstRegLoSelect, DstRegHiSelect});
SmallVector<Register, 4> Regs;
Regs.push_back(DstReg);
for (unsigned I = 0; I < NumPadElts; ++I)
Regs.push_back(MRI.createGenericVirtualRegister(DstTy));
MIRBuilder.buildUnmerge(Regs, NewDstReg);

MI.eraseFromParent();
return true;
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AIE/AIELegalizerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc. or its affiliates
// (c) Copyright 2023-2025 Advanced Micro Devices, Inc. or its affiliates
//
//===----------------------------------------------------------------------===//
/// \file
Expand Down Expand Up @@ -56,7 +56,8 @@ class AIELegalizerHelper {
bool legalizeG_FPEXT(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeG_FABS(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeG_FADDSUB(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeG_SELECT(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeG_SELECT(LegalizerHelper &Helper, MachineInstr &MI,
const unsigned MaxBitSize = 512) const;
bool legalizeG_BITCAST(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeLoopDecrement(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeG_CONCAT_VECTORS(LegalizerHelper &Helper,
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,14 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
getActionDefinitionsBuilder(G_SELECT)
.legalFor({{S32, S32}, {P0, S32}})
.clampScalar(1, S32, S32)
.legalFor(AIE2PVectorTypes)
// AIE2P ISA supports only 512-bit vector select
.legalFor({V16S32, V32S16, V64S8})
// For scalar types >= 256, bitcast to a vector type and use existing
// selection patterns
.bitcastIf(
[=](const LegalityQuery &Query) {
const LLT &ResTy = Query.Types[0];
if (ResTy.isVector())
return false;
return ResTy.getSizeInBits() >= 256;
return ResTy.isScalar() && ResTy.getSizeInBits() >= 256;
},
[=](const LegalityQuery &Query) {
const LLT Ty = Query.Types[0];
Expand All @@ -289,6 +288,7 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
})
.widenScalarToNextPow2(0)
.clampScalar(0, S32, S32)
.customIf(vectorSmallerThan(0, 512))
// We support G_SELECT only on the vector register bank
// Mapping the G_SELECT operands to the vector register bank
// during register bank selection introduces the proper cross-bank
Expand All @@ -299,7 +299,9 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
// patterns.
.bitcastIf(typeInSet(0, {AccV4S64, AccV8S64, AccV16S64}),
bitcastAccToVectorType(0))
.customFor({{AccV64S32, S32}});
.clampMaxNumElements(0, S8, 64)
.clampMaxNumElements(0, S16, 32)
.clampMaxNumElements(0, S32, 16);

getActionDefinitionsBuilder({G_ADD, G_SUB, G_XOR})
.legalFor({S32})
Expand Down Expand Up @@ -667,7 +669,7 @@ bool AIE2PLegalizerInfo::legalizeCustom(
case TargetOpcode::G_SEXT_INREG:
return AIEHelper.legalizeG_SEXT_INREG(Helper, MI);
case TargetOpcode::G_SELECT:
return AIEHelper.legalizeG_SELECT(Helper, MI);
return AIEHelper.legalizeG_SELECT(Helper, MI, /* MaxBitSize */ 512);
case TargetOpcode::G_CONCAT_VECTORS:
return AIEHelper.legalizeG_CONCAT_VECTORS(Helper, MI);
case TargetOpcode::G_BITCAST:
Expand Down
Loading

0 comments on commit b1b39ab

Please sign in to comment.