Skip to content

Commit

Permalink
Allow optimizing mask conversions on x64 as well (dotnet#110195)
Browse files Browse the repository at this point in the history
* Allow optimizing mask conversions on x64 as well

* Ensure the right operand is accessed on xarch

* Minimally handle CndSel as part of optimizing mask conversions

* Add some additional comments and clean up the logic a bit

* Apply formatting patch
  • Loading branch information
tannergooding authored and eduardo-vp committed Dec 4, 2024
1 parent 73e6b6b commit 06a071b
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 46 deletions.
52 changes: 49 additions & 3 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26786,7 +26786,55 @@ bool GenTree::OperIsHWIntrinsic(NamedIntrinsic intrinsicId) const
{
if (OperIsHWIntrinsic())
{
return AsHWIntrinsic()->GetHWIntrinsicId() == intrinsicId;
return AsHWIntrinsic()->OperIsHWIntrinsic(intrinsicId);
}
return false;
}

//------------------------------------------------------------------------
// OperIsConvertMaskToVector: Is this a ConvertMaskToVector hwintrinsic
//
// Return Value:
// true if the node is a ConvertMaskToVector hwintrinsic
// otherwise; false
//
bool GenTree::OperIsConvertMaskToVector() const
{
if (OperIsHWIntrinsic())
{
return AsHWIntrinsic()->OperIsConvertMaskToVector();
}
return false;
}

//------------------------------------------------------------------------
// OperIsConvertVectorToMask: Is this a ConvertVectorToMask hwintrinsic
//
// Return Value:
// true if the node is a ConvertVectorToMask hwintrinsic
// otherwise; false
//
bool GenTree::OperIsConvertVectorToMask() const
{
if (OperIsHWIntrinsic())
{
return AsHWIntrinsic()->OperIsConvertVectorToMask();
}
return false;
}

//------------------------------------------------------------------------
// OperIsVectorConditionalSelect: Is this a vector ConditionalSelect hwintrinsic
//
// Return Value:
// true if the node is a vector ConditionalSelect hwintrinsic
// otherwise; false
//
bool GenTree::OperIsVectorConditionalSelect() const
{
if (OperIsHWIntrinsic())
{
return AsHWIntrinsic()->OperIsVectorConditionalSelect();
}
return false;
}
Expand Down Expand Up @@ -30678,8 +30726,6 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
#if defined(FEATURE_HW_INTRINSICS)
GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
{
assert(tree->OperIsHWIntrinsic());

if (!opts.Tier0OptimizationEnabled())
{
return tree;
Expand Down
68 changes: 42 additions & 26 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -1665,32 +1665,9 @@ struct GenTree
}

bool OperIsHWIntrinsic(NamedIntrinsic intrinsicId) const;

bool OperIsConvertMaskToVector() const
{
#if defined(FEATURE_HW_INTRINSICS)
#if defined(TARGET_XARCH)
return OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector);
#elif defined(TARGET_ARM64)
return OperIsHWIntrinsic(NI_Sve_ConvertMaskToVector);
#endif // !TARGET_XARCH && !TARGET_ARM64
#else
return false;
#endif // FEATURE_HW_INTRINSICS
}

bool OperIsConvertVectorToMask() const
{
#if defined(FEATURE_HW_INTRINSICS)
#if defined(TARGET_XARCH)
return OperIsHWIntrinsic(NI_EVEX_ConvertVectorToMask);
#elif defined(TARGET_ARM64)
return OperIsHWIntrinsic(NI_Sve_ConvertVectorToMask);
#endif // !TARGET_XARCH && !TARGET_ARM64
#else
return false;
#endif // FEATURE_HW_INTRINSICS
}
bool OperIsConvertMaskToVector() const;
bool OperIsConvertVectorToMask() const;
bool OperIsVectorConditionalSelect() const;

// This is here for cleaner GT_LONG #ifdefs.
static bool OperIsLong(genTreeOps gtOper)
Expand Down Expand Up @@ -6583,6 +6560,45 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
bool OperIsBitwiseHWIntrinsic() const;
bool OperIsEmbRoundingEnabled() const;

bool OperIsHWIntrinsic(NamedIntrinsic intrinsicId) const
{
return GetHWIntrinsicId() == intrinsicId;
}

bool OperIsConvertMaskToVector() const
{
#if defined(TARGET_XARCH)
return OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector);
#elif defined(TARGET_ARM64)
return OperIsHWIntrinsic(NI_Sve_ConvertMaskToVector);
#else
return false;
#endif
}

bool OperIsConvertVectorToMask() const
{
#if defined(TARGET_XARCH)
return OperIsHWIntrinsic(NI_EVEX_ConvertVectorToMask);
#elif defined(TARGET_ARM64)
return OperIsHWIntrinsic(NI_Sve_ConvertVectorToMask);
#else
return false;
#endif
}

bool OperIsVectorConditionalSelect() const
{
#if defined(TARGET_XARCH)
return OperIsHWIntrinsic(NI_Vector128_ConditionalSelect) || OperIsHWIntrinsic(NI_Vector256_ConditionalSelect) ||
OperIsHWIntrinsic(NI_Vector512_ConditionalSelect);
#elif defined(TARGET_ARM64)
return OperIsHWIntrinsic(NI_AdvSimd_BitwiseSelect) || OperIsHWIntrinsic(NI_Sve_ConditionalSelect);
#else
return false;
#endif
}

bool OperRequiresAsgFlag() const;
bool OperRequiresCallFlag() const;
bool OperRequiresGlobRefFlag() const;
Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/jit/lsrabuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2766,12 +2766,12 @@ void LinearScan::buildIntervals()
{
calleeSaveCount = CNT_CALLEE_ENREG;
}
#if (defined(TARGET_XARCH) || defined(TARGET_ARM64)) && defined(FEATURE_SIMD)
#if defined(FEATURE_MASKED_HW_INTRINSICS)
else if (varTypeUsesMaskReg(interval->registerType))
{
calleeSaveCount = CNT_CALLEE_SAVED_MASK;
}
#endif // (TARGET_XARCH || TARGET_ARM64) && FEATURE_SIMD
#endif // FEATURE_MASKED_HW_INTRINSICS
else
{
assert(varTypeUsesFloatReg(interval->registerType));
Expand Down
71 changes: 56 additions & 15 deletions src/coreclr/jit/optimizemaskconversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "jitpch.h"

#if defined(TARGET_ARM64)
#if defined(FEATURE_MASKED_HW_INTRINSICS)

struct MaskConversionsWeight
{
Expand All @@ -19,8 +19,13 @@ struct MaskConversionsWeight
// Conversion of mask to vector is one instruction.
static constexpr const weight_t costOfConvertMaskToVector = 1.0;

#if defined(TARGET_ARM64)
// Conversion of vector to mask is two instructions.
static constexpr const weight_t costOfConvertVectorToMask = 2.0;
#else
// Conversion of vector to mask is one instructions.
static constexpr const weight_t costOfConvertVectorToMask = 1.0;
#endif

// The simd types of the Lcl Store after conversion to vector.
CorInfoType simdBaseJitType = CORINFO_TYPE_UNDEF;
Expand Down Expand Up @@ -136,6 +141,7 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
switch ((*use)->OperGet())
{
case GT_STORE_LCL_VAR:
{
isLocalStore = true;

// Look for:
Expand All @@ -147,19 +153,48 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
hasConversion = true;
}
break;
}

case GT_LCL_VAR:
{
isLocalUse = true;

// Look for:
// user:ConvertVectorToMask(use:LCL_VAR(x)))
// user: ConvertVectorToMask(use:LCL_VAR(x)))
// -or-
// user: ConditionalSelect(use:LCL_VAR(x), y, z)

if (user->OperIsConvertVectorToMask())
if (user->OperIsHWIntrinsic())
{
convertOp = user->AsHWIntrinsic();
hasConversion = true;
GenTreeHWIntrinsic* hwintrin = user->AsHWIntrinsic();
NamedIntrinsic ni = hwintrin->GetHWIntrinsicId();

if (hwintrin->OperIsConvertVectorToMask())
{
convertOp = user->AsHWIntrinsic();
hasConversion = true;
}
else if (hwintrin->OperIsVectorConditionalSelect())
{
// We don't actually have a convert here, but we do have a case where
// the mask is being used in a ConditionalSelect and therefore can be
// consumed directly as a mask. While the IR shows TYP_SIMD, it gets
// handled in lowering as part of the general embedded-mask support.

// We notably don't check that op2->isEmbeddedMaskingCompatibleHWIntrinsic()
// because we can still consume the mask directly in such cases. We'll just
// emit `vblendmps zmm1 {k1}, zmm2, zmm3` instead of containing the CndSel
// as part of something like `vaddps zmm1 {k1}, zmm2, zmm3`

if (hwintrin->Op(1) == (*use))
{
convertOp = user->AsHWIntrinsic();
hasConversion = true;
}
}
}
break;
}

default:
break;
Expand Down Expand Up @@ -254,6 +289,12 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions

Compiler::fgWalkResult PostOrderVisit(GenTree** use, GenTree* user)
{
#if defined(TARGET_ARM64)
static constexpr const int ConvertVectorToMaskValueOp = 2;
#else
static constexpr const int ConvertVectorToMaskValueOp = 1;
#endif

GenTreeLclVarCommon* lclOp = nullptr;
bool isLocalStore = false;
bool isLocalUse = false;
Expand All @@ -276,11 +317,12 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
isLocalStore = true;
addConversion = true;
}
else if ((*use)->OperIsConvertVectorToMask() && (*use)->AsHWIntrinsic()->Op(2)->OperIs(GT_LCL_VAR))
else if ((*use)->OperIsConvertVectorToMask() &&
(*use)->AsHWIntrinsic()->Op(ConvertVectorToMaskValueOp)->OperIs(GT_LCL_VAR))
{
// Found
// user(use:ConvertVectorToMask(LCL_VAR(x)))
lclOp = (*use)->AsHWIntrinsic()->Op(2)->AsLclVarCommon();
lclOp = (*use)->AsHWIntrinsic()->Op(ConvertVectorToMaskValueOp)->AsLclVarCommon();
isLocalUse = true;
removeConversion = true;
}
Expand Down Expand Up @@ -393,7 +435,7 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
MaskConversionsWeightTable* weightsTable;
};

#endif // TARGET_ARM64
#endif // FEATURE_MASKED_HW_INTRINSICS

//------------------------------------------------------------------------
// fgOptimizeMaskConversions: Allow locals to be of Mask type
Expand Down Expand Up @@ -445,7 +487,7 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
//
PhaseStatus Compiler::fgOptimizeMaskConversions()
{
#if defined(TARGET_ARM64)
#if defined(FEATURE_MASKED_HW_INTRINSICS)

if (opts.OptimizationDisabled())
{
Expand Down Expand Up @@ -476,10 +518,10 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
{
for (Statement* const stmt : block->Statements())
{
// Only check statements where there is a local of type TYP_SIMD16/TYP_MASK.
// Only check statements where there is a local of type TYP_SIMD/TYP_MASK.
for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
{
if (lcl->TypeIs(TYP_SIMD16, TYP_MASK))
if (varTypeIsSIMDOrMask(lcl))
{
// Parse the entire statement.
MaskConversionsCheckVisitor ev(this, block->getBBWeight(this), &weightsTable);
Expand All @@ -504,10 +546,10 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
{
for (Statement* const stmt : block->Statements())
{
// Only check statements where there is a local of type TYP_SIMD16/TYP_MASK.
// Only check statements where there is a local of type TYP_SIMD/TYP_MASK.
for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
{
if (lcl->TypeIs(TYP_SIMD16, TYP_MASK))
if (varTypeIsSIMDOrMask(lcl))
{
// Parse the entire statement.
MaskConversionsUpdateVisitor ev(this, stmt, &weightsTable);
Expand All @@ -524,8 +566,7 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
}

return PhaseStatus::MODIFIED_EVERYTHING;

#else
return PhaseStatus::MODIFIED_NOTHING;
#endif // TARGET_ARM64
#endif // FEATURE_MASKED_HW_INTRINSICS
}

0 comments on commit 06a071b

Please sign in to comment.