Skip to content

Commit

Permalink
[llvm] support multiple save/restore points in mir
Browse files Browse the repository at this point in the history
Currently mir supports only one save and one restore point specification:

```
  savePoint:       '%bb.1'
  restorePoint:    '%bb.2'
```

This patch provide possibility to specify multiple save and multiple restore points in mir:

```
  savePoint:
    - point:           '%bb.1'
  restorePoint:
    - point:           '%bb.2'
```
while maintaining backward compatibility.
  • Loading branch information
enoskova-sc committed Feb 12, 2025
1 parent ad152f4 commit 1daee32
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 47 deletions.
74 changes: 67 additions & 7 deletions llvm/include/llvm/CodeGen/MIRYamlMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,53 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)
namespace llvm {
namespace yaml {

struct SaveRestorePointEntry {
StringValue Point;

bool operator==(const SaveRestorePointEntry &Other) const {
return Point == Other.Point;
}
};

using SaveRestorePoints =
std::variant<std::vector<SaveRestorePointEntry>, StringValue>;

template <> struct PolymorphicTraits<SaveRestorePoints> {

static NodeKind getKind(const SaveRestorePoints &SRPoints) {
if (std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
return NodeKind::Sequence;
if (std::holds_alternative<StringValue>(SRPoints))
return NodeKind::Scalar;
llvm_unreachable("Unknown map value kind in SaveRestorePoints");
}

static SaveRestorePointEntry &getAsMap(SaveRestorePoints &SRPoints) {
llvm_unreachable("111");
}

static std::vector<SaveRestorePointEntry> &
getAsSequence(SaveRestorePoints &SRPoints) {
if (!std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
SRPoints = std::vector<SaveRestorePointEntry>();

return std::get<std::vector<SaveRestorePointEntry>>(SRPoints);
}

static StringValue &getAsScalar(SaveRestorePoints &SRPoints) {
if (!std::holds_alternative<StringValue>(SRPoints))
SRPoints = StringValue();

return std::get<StringValue>(SRPoints);
}
};

template <> struct MappingTraits<SaveRestorePointEntry> {
static void mapping(IO &YamlIO, SaveRestorePointEntry &Entry) {
YamlIO.mapRequired("point", Entry.Point);
}
};

template <> struct MappingTraits<MachineJumpTable> {
static void mapping(IO &YamlIO, MachineJumpTable &JT) {
YamlIO.mapRequired("kind", JT.Kind);
Expand All @@ -639,6 +686,14 @@ template <> struct MappingTraits<MachineJumpTable> {
}
};

} // namespace yaml
} // namespace llvm

LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::SaveRestorePointEntry)

namespace llvm {
namespace yaml {

/// Serializable representation of MachineFrameInfo.
///
/// Doesn't serialize attributes like 'StackAlignment', 'IsStackRealignable' and
Expand Down Expand Up @@ -666,8 +721,8 @@ struct MachineFrameInfo {
bool HasTailCall = false;
bool IsCalleeSavedInfoValid = false;
unsigned LocalFrameSize = 0;
StringValue SavePoint;
StringValue RestorePoint;
SaveRestorePoints SavePoints;
SaveRestorePoints RestorePoints;

bool operator==(const MachineFrameInfo &Other) const {
return IsFrameAddressTaken == Other.IsFrameAddressTaken &&
Expand All @@ -688,7 +743,8 @@ struct MachineFrameInfo {
HasMustTailInVarArgFunc == Other.HasMustTailInVarArgFunc &&
HasTailCall == Other.HasTailCall &&
LocalFrameSize == Other.LocalFrameSize &&
SavePoint == Other.SavePoint && RestorePoint == Other.RestorePoint &&
SavePoints == Other.SavePoints &&
RestorePoints == Other.RestorePoints &&
IsCalleeSavedInfoValid == Other.IsCalleeSavedInfoValid;
}
};
Expand Down Expand Up @@ -720,10 +776,14 @@ template <> struct MappingTraits<MachineFrameInfo> {
YamlIO.mapOptional("isCalleeSavedInfoValid", MFI.IsCalleeSavedInfoValid,
false);
YamlIO.mapOptional("localFrameSize", MFI.LocalFrameSize, (unsigned)0);
YamlIO.mapOptional("savePoint", MFI.SavePoint,
StringValue()); // Don't print it out when it's empty.
YamlIO.mapOptional("restorePoint", MFI.RestorePoint,
StringValue()); // Don't print it out when it's empty.
YamlIO.mapOptional(
"savePoint", MFI.SavePoints,
SaveRestorePoints(
StringValue())); // Don't print it out when it's empty.
YamlIO.mapOptional(
"restorePoint", MFI.RestorePoints,
SaveRestorePoints(
StringValue())); // Don't print it out when it's empty.
}
};

Expand Down
55 changes: 41 additions & 14 deletions llvm/lib/CodeGen/MIRParser/MIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class MIRParserImpl {
bool initializeFrameInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineFunction &YamlMF);

bool initializeSaveRestorePoints(PerFunctionMIParsingState &PFS,
const yaml::SaveRestorePoints &YamlSRP,
bool IsSavePoints);

bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineFunction &YamlMF);

Expand Down Expand Up @@ -851,18 +855,9 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
MFI.setHasTailCall(YamlMFI.HasTailCall);
MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
if (!YamlMFI.SavePoint.Value.empty()) {
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, YamlMFI.SavePoint))
return true;
MFI.setSavePoint(MBB);
}
if (!YamlMFI.RestorePoint.Value.empty()) {
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, YamlMFI.RestorePoint))
return true;
MFI.setRestorePoint(MBB);
}
initializeSaveRestorePoints(PFS, YamlMFI.SavePoints, true /*IsSavePoints*/);
initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints,
false /*IsSavePoints*/);

std::vector<CalleeSavedInfo> CSIInfo;
// Initialize the fixed frame objects.
Expand Down Expand Up @@ -1077,8 +1072,40 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
return false;
}

bool MIRParserImpl::initializeJumpTableInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineJumpTable &YamlJTI) {
bool MIRParserImpl::initializeSaveRestorePoints(
PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRP,
bool IsSavePoints) {
MachineBasicBlock *MBB = nullptr;
if (std::holds_alternative<std::vector<yaml::SaveRestorePointEntry>>(
YamlSRP)) {
const auto &VectorRepr =
std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRP);
if (VectorRepr.empty())
return false;

const auto &Entry = VectorRepr.front();
const auto &MBBSource = Entry.Point;
if (parseMBBReference(PFS, MBB, MBBSource.Value))
return true;
} else {
yaml::StringValue StringRepr = std::get<yaml::StringValue>(YamlSRP);
if (StringRepr.Value.empty() || parseMBBReference(PFS, MBB, StringRepr))
return true;
}

MachineFunction &MF = PFS.MF;
MachineFrameInfo &MFI = MF.getFrameInfo();

if (IsSavePoints)
MFI.setSavePoint(MBB);
else
MFI.setRestorePoint(MBB);

return false;
}

bool MIRParserImpl::initializeJumpTableInfo(
PerFunctionMIParsingState &PFS, const yaml::MachineJumpTable &YamlJTI) {
MachineJumpTableInfo *JTI = PFS.MF.getOrCreateJumpTableInfo(YamlJTI.Kind);
for (const auto &Entry : YamlJTI.Entries) {
std::vector<MachineBasicBlock *> Blocks;
Expand Down
26 changes: 18 additions & 8 deletions llvm/lib/CodeGen/MIRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class MIRPrinter {
const TargetRegisterInfo *TRI);
void convert(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
const MachineFrameInfo &MFI);
void convert(ModuleSlotTracker &MST, yaml::SaveRestorePoints &YamlSRP,
MachineBasicBlock *SaveRestorePoint);
void convert(yaml::MachineFunction &MF,
const MachineConstantPool &ConstantPool);
void convert(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
Expand Down Expand Up @@ -397,14 +399,10 @@ void MIRPrinter::convert(ModuleSlotTracker &MST,
YamlMFI.HasTailCall = MFI.hasTailCall();
YamlMFI.IsCalleeSavedInfoValid = MFI.isCalleeSavedInfoValid();
YamlMFI.LocalFrameSize = MFI.getLocalFrameSize();
if (MFI.getSavePoint()) {
raw_string_ostream StrOS(YamlMFI.SavePoint.Value);
StrOS << printMBBReference(*MFI.getSavePoint());
}
if (MFI.getRestorePoint()) {
raw_string_ostream StrOS(YamlMFI.RestorePoint.Value);
StrOS << printMBBReference(*MFI.getRestorePoint());
}
if (MFI.getSavePoint())
convert(MST, YamlMFI.SavePoints, MFI.getSavePoint());
if (MFI.getRestorePoint())
convert(MST, YamlMFI.RestorePoints, MFI.getRestorePoint());
}

void MIRPrinter::convertEntryValueObjects(yaml::MachineFunction &YMF,
Expand Down Expand Up @@ -646,6 +644,18 @@ void MIRPrinter::convert(yaml::MachineFunction &MF,
}
}

void MIRPrinter::convert(ModuleSlotTracker &MST,
yaml::SaveRestorePoints &YamlSRP,
MachineBasicBlock *SRP) {
std::string Str;
yaml::SaveRestorePointEntry Entry;
raw_string_ostream StrOS(Str);
StrOS << printMBBReference(*SRP);
Entry.Point = StrOS.str();
auto &Points = std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRP);
Points.push_back(Entry);
}

void MIRPrinter::convert(ModuleSlotTracker &MST,
yaml::MachineJumpTable &YamlJTI,
const MachineJumpTableInfo &JTI) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
; RUN: llc -x=mir -simplify-mir -run-pass=shrink-wrap -o - %s | FileCheck %s
; CHECK: name: compiler_pop_stack
; CHECK: frameInfo:
; CHECK: savePoint: '%bb.1'
; CHECK: restorePoint: '%bb.7'
; CHECK: savePoint:
; CHECK-NEXT: - point: '%bb.1'
; CHECK: restorePoint:
; CHECK-NEXT: - point: '%bb.7'
; CHECK: name: compiler_pop_stack_no_memoperands
; CHECK: frameInfo:
; CHECK: savePoint: '%bb.1'
; CHECK: restorePoint: '%bb.7'
; CHECK: savePoint:
; CHECK-NEXT: - point: '%bb.1'
; CHECK: restorePoint:
; CHECK-NEXT: - point: '%bb.7'
; CHECK: name: f
; CHECK: frameInfo:
; CHECK: savePoint: '%bb.2'
; CHECK-NEXT: restorePoint: '%bb.4'
; CHECK-NEXT: stack:
; CHECK: savePoint:
; CHECK-NEXT: - point: '%bb.2'
; CHECK: restorePoint:
; CHECK-NEXT: - point: '%bb.4'
; CHECK: stack:

target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
target triple = "aarch64"
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/ARM/invalidated-save-point.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
; this point. Notably, if it isn't is will be invalid and reference a
; deleted block (%bb.-1.if.end)

; CHECK: savePoint: ''
; CHECK: restorePoint: ''
; CHECK: savePoint: []
; CHECK: restorePoint: []

target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
target triple = "thumbv7"
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/MIR/Generic/frame-info.mir
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ tracksRegLiveness: true
# CHECK-NEXT: hasTailCall: false
# CHECK-NEXT: isCalleeSavedInfoValid: false
# CHECK-NEXT: localFrameSize: 0
# CHECK-NEXT: savePoint: ''
# CHECK-NEXT: restorePoint: ''
# CHECK-NEXT: savePoint: []
# CHECK-NEXT: restorePoint: []
# CHECK: body
frameInfo:
maxAlignment: 4
Expand Down
6 changes: 4 additions & 2 deletions llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ liveins:
- { reg: '$edi' }
- { reg: '$esi' }
# CHECK: frameInfo:
# CHECK: savePoint: '%bb.2'
# CHECK-NEXT: restorePoint: '%bb.2'
# CHECK: savePoint:
# CHECK-NEXT: - point: '%bb.2'
# CHECK: restorePoint:
# CHECK-NEXT: - point: '%bb.2'
# CHECK: stack
frameInfo:
maxAlignment: 4
Expand Down
6 changes: 4 additions & 2 deletions llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ frameInfo:
hasOpaqueSPAdjustment: false
hasVAStart: false
hasMustTailInVarArgFunc: false
# CHECK: savePoint: '%bb.1'
# CHECK: restorePoint: '%bb.3'
# CHECK: savePoint:
# CHECK-NEXT: - point: '%bb.1'
# CHECK: restorePoint:
# CHECK-NEXT: - point: '%bb.3'
savePoint: ''
restorePoint: ''
fixedStack:
Expand Down
8 changes: 5 additions & 3 deletions llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
# RESULT-NEXT: hasVAStart: true
# RESULT-NEXT: hasMustTailInVarArgFunc: true
# RESULT-NEXT: hasTailCall: true
# RESULT-NEXT: savePoint: '%bb.1'
# RESULT-NEXT: restorePoint: '%bb.2'
# RESULT-NEXT: savePoint:
# RESULT-NEXT: - point: '%bb.1'
# RESULT-NEXT: restorePoint:
# RESULT-NEXT: - point: '%bb.1'

# RESULT-NEXT: fixedStack:
# RESULT-NEXT: - { id: 0, offset: 56, size: 4, alignment: 8, callee-saved-register: '$sgpr44',
Expand Down Expand Up @@ -117,7 +119,7 @@ frameInfo:
hasTailCall: true
localFrameSize: 0
savePoint: '%bb.1'
restorePoint: '%bb.2'
restorePoint: '%bb.1'

fixedStack:
- { id: 0, offset: 0, size: 8, alignment: 4, isImmutable: true, isAliased: false }
Expand Down

0 comments on commit 1daee32

Please sign in to comment.