From ddd039e326d73124f25296317db44d7075f7ec05 Mon Sep 17 00:00:00 2001 From: Maurice Heumann Date: Thu, 5 Dec 2024 14:54:39 +0100 Subject: [PATCH] [InstCombine] Prevent infinite loop with two shifts The following pattern: `(C2 << X) << C1` will usually be transformed into `(C2 << C1) << X`, essentially swapping `X` and `C1`. However, this should not only done when `C1` is an immediate constant, otherwise this can lead to both constants being swapped forever This fixes #118798 --- .../Transforms/InstCombine/InstCombineShifts.cpp | 3 ++- .../Transforms/InstCombine/shl-twice-constant.ll | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 llvm/test/Transforms/InstCombine/shl-twice-constant.ll diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 10c3ccdb2243a1..d511e79e3e48ae 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -427,7 +427,8 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - if (Constant *CUI = dyn_cast(Op1)) + Constant *CUI; + if (match(Op1, m_ImmConstant(CUI))) if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; diff --git a/llvm/test/Transforms/InstCombine/shl-twice-constant.ll b/llvm/test/Transforms/InstCombine/shl-twice-constant.ll new file mode 100644 index 00000000000000..bbdd7fa3d1c406 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/shl-twice-constant.ll @@ -0,0 +1,16 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +@c = external constant i8 +@c2 = external constant i8 + +define i64 @testfunc() { +; CHECK-LABEL: @testfunc( +; CHECK-NEXT: [[SHL1:%.*]] = shl nuw i64 1, ptrtoint (ptr @c2 to i64) +; CHECK-NEXT: [[SHL2:%.*]] = shl i64 [[SHL1]], ptrtoint (ptr @c to i64) +; CHECK-NEXT: ret i64 [[SHL2]] +; + %shl1 = shl i64 1, ptrtoint (ptr @c2 to i64) + %shl2 = shl i64 %shl1, ptrtoint (ptr @c to i64) + ret i64 %shl2 +}