Skip to content

Commit

Permalink
[WIP][AIE2P] Refactor register bank selection
Browse files Browse the repository at this point in the history
  • Loading branch information
niwinanto committed Jan 31, 2025
1 parent ce3dc2c commit 3f14d94
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 69 deletions.
124 changes: 66 additions & 58 deletions llvm/lib/Target/AIE/aie2p/AIE2PRegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,9 @@ static bool isUsedAsFifoRegInIntrinsic(const MachineRegisterInfo &MRI,
}
/// \returns true if the specified intrinsic has an accumulator
/// vector as one of its operands.
static bool isAccIntrinsic(const MachineRegisterInfo &MRI,
const MachineInstr &MI, const Register &AccReg) {
static bool isUsedAsAccRegInIntrinsic(const MachineRegisterInfo &MRI,
const MachineInstr &MI,
const Register &AccReg) {
switch (cast<GIntrinsic>(MI).getIntrinsicID()) {
// All Intrinsics with accumlator destination operand
case Intrinsic::aie2p_vbroadcast_zero_acc1024:
Expand Down Expand Up @@ -670,17 +671,16 @@ static bool isAccIntrinsic(const MachineRegisterInfo &MRI,
return false;
}

bool AIE2PRegisterBankInfo::usesAccReg(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register &RegOp) const {
bool AIE2PRegisterBankInfo::isUsedAsAccRegInInstr(
const MachineInstr &MI, const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI, Register Reg) const {
unsigned Op = MI.getOpcode();
switch (Op) {
default:
break;
case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
return isAccIntrinsic(MRI, MI, RegOp);
return isUsedAsAccRegInIntrinsic(MRI, MI, Reg);
case TargetOpcode::COPY: {
Register DstReg = MI.getOperand(0).getReg();
if (isAccReg(DstReg))
Expand All @@ -697,20 +697,19 @@ bool AIE2PRegisterBankInfo::usesAccReg(const MachineInstr &MI,
return false;
}

// Check if the instruction has RegOp as a fifo input.
// Check if the instruction has Reg as a fifo input.
// Similar to usesAccReg for Accumulators
bool AIE2PRegisterBankInfo::hasFifoInput(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register RegOp) const {
bool AIE2PRegisterBankInfo::isUsedAsFifoRegInInstr(
const MachineInstr &MI, const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI, Register Reg) const {
auto *RI = static_cast<const AIEBaseRegisterInfo *>(
MI.getParent()->getParent()->getSubtarget().getRegisterInfo());
switch (MI.getOpcode()) {
default:
break;
case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
return isUsedAsFifoRegInIntrinsic(MRI, MI, RegOp);
return isUsedAsFifoRegInIntrinsic(MRI, MI, Reg);
case TargetOpcode::COPY: {
Register DstReg = MI.getOperand(0).getReg();
if (RI->isFifoPhysReg(DstReg))
Expand All @@ -727,21 +726,46 @@ bool AIE2PRegisterBankInfo::hasFifoInput(const MachineInstr &MI,
return false;
}

bool AIE2PRegisterBankInfo::isUseAccInsn(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register &RegOp) const {
bool AIE2PRegisterBankInfo::isUsedAsAccReg(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
Register Reg) const {
return any_of(MRI.use_nodbg_instructions(Reg),
[&](const MachineInstr &UseMI) {
return isUsedAsAccRegInInstr(UseMI, MRI, TRI, Reg);
});
}

bool AIE2PRegisterBankInfo::isUsedAsFifoReg(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
Register Reg) const {
return any_of(MRI.use_nodbg_instructions(Reg),
[&](const MachineInstr &UseMI) {
return isUsedAsFifoRegInInstr(UseMI, MRI, TRI, Reg);
});
}

const RegisterBank *AIE2PRegisterBankInfo::getPreferredRegBankForVectorTy(
const MachineRegisterInfo &MRI, const TargetRegisterInfo &TRI,
Register Reg) const {
auto RegTy = MRI.getType(Reg);
auto RegSize = RegTy.getSizeInBits();
if (RegSize == 256)
return &getRegBank(AIE2P::VRegBankID);
if (RegSize == 2048)
return &getRegBank(AIE2P::AccRegBankID);

// Helper function to trace through COPY and G_BITCAST
auto traceToActualReg = [&](Register &Reg) {
for (auto &UseMI : MRI.use_nodbg_instructions(Reg)) {
MachineInstr *ConvUseMI = &UseMI;
auto TraceToActualReg = [&](Register &SkipReg) {
if (!MRI.use_empty(SkipReg)) {
MachineInstr *ConvUseMI = &*MRI.use_instr_nodbg_begin(SkipReg);
unsigned ConvUseOpc = ConvUseMI->getOpcode();
// skip copies
while (ConvUseOpc == TargetOpcode::G_BITCAST ||
ConvUseOpc == TargetOpcode::COPY) {
Register DefReg = ConvUseMI->getOperand(0).getReg();
if (DefReg.isPhysical())
break;
Reg = DefReg;
SkipReg = DefReg;
if (MRI.use_empty(DefReg))
break;
ConvUseMI = &*MRI.use_instr_nodbg_begin(DefReg);
Expand All @@ -750,35 +774,13 @@ bool AIE2PRegisterBankInfo::isUseAccInsn(const MachineRegisterInfo &MRI,
}
};

// Create a non-const copy of RegOp for tracing
Register TracedRegOp = RegOp;

// Trace RegOp to its actual register, updating it if necessary
traceToActualReg(TracedRegOp);

// Now check for accumulator-usage using the updated TracedRegOp
return any_of(MRI.use_nodbg_instructions(TracedRegOp),
[&](const MachineInstr &UseMI) {
// Check if this instruction uses the accumulator register
return (MRI.getType(TracedRegOp).getSizeInBits() == 2048) ||
usesAccReg(UseMI, MRI, TRI, TracedRegOp);
});
}

// Check if RegOp is used as a fifo register.
bool AIE2PRegisterBankInfo::isUseFifoInsn(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register RegOp) const {
// TODO: Trace RegOp to its actual definition ignoring COPY and G_BITCAST as
// we do for accumulator bank.
// Use getDefIgnoringCopiesAndBitcasts and refactor the code between
// accumulator and fifo handling.

return any_of(MRI.use_nodbg_instructions(RegOp),
[&](const MachineInstr &UseMI) {
// Check if this instruction uses the accumulator register
return hasFifoInput(UseMI, MRI, TRI, RegOp);
});
// Trace Reg to its actual register, updating it if necessary
TraceToActualReg(Reg);
if (isUsedAsAccReg(MRI, TRI, Reg))
return &getRegBank(AIE2P::AccRegBankID);
if (isUsedAsFifoReg(MRI, TRI, Reg))
return &getRegBank(AIE2P::FifoRegBankID);
return &getRegBank(AIE2P::VRegBankID);
}

const RegisterBankInfo::InstructionMapping &
Expand Down Expand Up @@ -808,7 +810,8 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
auto *RB = getRegBank(DstReg, MRI, TRI);
if (RB == &AIE2P::AccRegBank)
return AIEBaseRegisterBankInfo::getInstrMapping(MI);
if (isUseAccInsn(MRI, TRI, DstReg)) {
if (&AIE2P::AccRegBank ==
getPreferredRegBankForVectorTy(MRI, TRI, DstReg)) {
OpRegBankIdx[0] = getAccPartialMappingIdx(DstType);
for (unsigned Idx = 2; Idx < NumOperands; ++Idx) {
LLT Type = MRI.getType(MI.getOperand(Idx).getReg());
Expand All @@ -826,7 +829,7 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
LLT Ty = MRI.getType(MO.getReg());
if (!Ty.isValid())
continue;
if (isAccIntrinsic(MRI, MI, MI.getOperand(Idx).getReg())) {
if (isUsedAsAccRegInIntrinsic(MRI, MI, MI.getOperand(Idx).getReg())) {
LLT Type = MRI.getType(MI.getOperand(Idx).getReg());
OpRegBankIdx[Idx] = getAccPartialMappingIdx(Type);
continue;
Expand All @@ -849,7 +852,8 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
auto *RB = getRegBank(DstReg, MRI, TRI);
if (RB == &AIE2P::AccRegBank)
return AIEBaseRegisterBankInfo::getInstrMapping(MI);
if (isUseAccInsn(MRI, TRI, DstReg)) {
if (&AIE2P::AccRegBank ==
getPreferredRegBankForVectorTy(MRI, TRI, DstReg)) {
OpRegBankIdx[0] = getAccPartialMappingIdx(Type);
return AIEBaseRegisterBankInfo::getInstrMappingFinal(MI, Cost, OpSize,
OpRegBankIdx);
Expand All @@ -864,7 +868,7 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
LLT SrcType = MRI.getType(SrcReg);
// Check if we already know the register bank.
auto *RB = getRegBank(SrcReg, MRI, TRI);
if (isUseAccInsn(MRI, TRI, DstReg))
if (&AIE2P::AccRegBank == getPreferredRegBankForVectorTy(MRI, TRI, DstReg))
OpRegBankIdx[0] = getAccPartialMappingIdx(DstType);
else
OpRegBankIdx[0] = getPartialMappingIdx(DstType);
Expand Down Expand Up @@ -924,7 +928,8 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
OpRegBankIdx);
}

if (isUseAccInsn(MRI, TRI, DstReg)) {
if (&AIE2P::AccRegBank ==
getPreferredRegBankForVectorTy(MRI, TRI, DstReg)) {
OpRegBankIdx[0] = getAccPartialMappingIdx(DstType);
for (unsigned Idx = 1; Idx < NumOperands; ++Idx) {
LLT Type = MRI.getType(MI.getOperand(Idx).getReg());
Expand All @@ -945,9 +950,11 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
auto *RB = getRegBank(UseCandidate, MRI, TRI);
if (RB == &AIE2P::AccRegBank || RB == &AIE2P::FifoRegBank)
return AIEBaseRegisterBankInfo::getInstrMapping(MI);
if (isUseAccInsn(MRI, TRI, UseCandidate))
const auto *PreferredRB =
getPreferredRegBankForVectorTy(MRI, TRI, UseCandidate);
if (&AIE2P::AccRegBank == PreferredRB)
isAccRegMapping = true;
if (isUseFifoInsn(MRI, TRI, UseCandidate))
if (&AIE2P::FifoRegBank == PreferredRB)
isFifoPhysRegMapping = true;
// size of accu and fifo vector on aie2p >= 512.
MachineMemOperand *MMO = *MI.memoperands_begin();
Expand All @@ -963,9 +970,10 @@ AIE2PRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
UseCandidate = DefMI->getOperand(0).getReg();
Type = MRI.getType(MI.getOperand(0).getReg());
}
if (isUseAccInsn(MRI, TRI, UseCandidate))
PreferredRB = getPreferredRegBankForVectorTy(MRI, TRI, UseCandidate);
if (&AIE2P::AccRegBank == PreferredRB)
isAccRegMapping = true;
if (isUseFifoInsn(MRI, TRI, UseCandidate))
if (&AIE2P::FifoRegBank == PreferredRB)
isFifoPhysRegMapping = true;
if (isAccRegMapping) {
OpRegBankIdx[0] = getAccPartialMappingIdx(Type);
Expand Down
29 changes: 18 additions & 11 deletions llvm/lib/Target/AIE/aie2p/AIE2PRegisterBankInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#define LLVM_LIB_TARGET_AIE2P_AIE2PREGISTERBANKINFO_H

#include "AIEBaseRegisterBankInfo.h"
#include "llvm/CodeGen/RegisterBank.h"
#include <optional>

#define GET_REGBANK_DECLARATIONS
#include "AIE2PGenRegisterBank.inc"
Expand Down Expand Up @@ -77,17 +79,22 @@ class AIE2PRegisterBankInfo final : public AIE2PGenRegisterBankInfo {
getInstrMapping(const MachineInstr &MI) const override;
const RegisterBank &getRegBankFromRegClass(const TargetRegisterClass &RC,
LLT) const override;
bool usesAccReg(const MachineInstr &MI, const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI, const Register &AccReg) const;
bool isUseAccInsn(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register &AccReg) const;
bool hasFifoInput(const MachineInstr &MI, const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register FifoReg) const;
bool isUseFifoInsn(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const Register FifoReg) const;
bool isUsedAsAccRegInInstr(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
Register AccReg) const;
bool isUsedAsAccReg(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI, Register AccReg) const;
bool isUsedAsFifoRegInInstr(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
Register FifoReg) const;
bool isUsedAsFifoReg(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI, Register FifoReg) const;
const RegisterBank *
getPreferredRegBankForVectorTy(const MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
Register Reg) const;
};
} // end namespace llvm
#endif

0 comments on commit 3f14d94

Please sign in to comment.