[InstCombine] Add constant combines for (urem/srem (mul X, Y), (mul X, Z))

We can handle the following cases + some `nsw`/`nuw` flags:

`(srem (mul X, Y), (mul X, Z))`
    [If `srem(Y, Z) == 0`]
        -> 0
            - https://alive2.llvm.org/ce/z/PW4XZ-
    [If `srem(Y, Z) == Y`]
        -> `(mul nuw nsw X, Y)`
            - https://alive2.llvm.org/ce/z/DQe9Ek
        -> `(mul nsw X, Y)`
            - https://alive2.llvm.org/ce/z/Nr_MdH

    [If `Y`/`Z` are constant]
        -> `(mul/shl nuw nsw X, (srem Y, Z))`
            - https://alive2.llvm.org/ce/z/ccTFj2
            - https://alive2.llvm.org/ce/z/i_UQ5A
        -> `(mul/shl nsw X, (srem Y, Z))`
            - https://alive2.llvm.org/ce/z/mQKc63
            - https://alive2.llvm.org/ce/z/uERkKH

`(urem (mul X, Y), (mul X, Z))`
    [If `urem(Y, Z) == 0`]
        -> 0
            - https://alive2.llvm.org/ce/z/LL7UVR
    [If `srem(Y, Z) == Y`]
        -> `(mul nuw nsw X, Y)`
            - https://alive2.llvm.org/ce/z/9Kgs_i
        -> `(mul nuw X, Y)`
            - https://alive2.llvm.org/ce/z/ow9i8u

    [If `Y`/`Z` are constant]
        -> `(mul nuw nsw X, (srem Y, Z))`
            - https://alive2.llvm.org/ce/z/mNnQqJ
            - https://alive2.llvm.org/ce/z/Bj_DR-
            - https://alive2.llvm.org/ce/z/X6ZEtQ
        -> `(mul nuw X, (srem Y, Z))`
            - https://alive2.llvm.org/ce/z/SJYtUV

The rationale for doing this all in `InstCombine` rather than handling
the constant `mul` cases in `InstSimplify` is we often create a new
instruction because we are able to deduce more `nsw`/`nuw` flags than
the original instruction had.

Reviewed By: MattDevereau, sdesmalen

Differential Revision: https://reviews.llvm.org/D143014
This commit is contained in:
Noah Goldstein 2023-03-15 16:53:13 -05:00
parent 994cd986f1
commit aba71f37d0
2 changed files with 68 additions and 28 deletions

View File

@ -1698,6 +1698,63 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return nullptr;
}
// Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z))
static Instruction *simplifyIRemMulShl(BinaryOperator &I,
InstCombinerImpl &IC) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X;
const APInt *Y, *Z;
if (!(match(Op0, m_Mul(m_Value(X), m_APInt(Y))) &&
match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) &&
!(match(Op0, m_Mul(m_APInt(Y), m_Value(X))) &&
match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))))
return nullptr;
bool IsSRem = I.getOpcode() == Instruction::SRem;
OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0);
// TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >=
// Z or Z >= Y.
bool BO0HasNSW = BO0->hasNoSignedWrap();
bool BO0HasNUW = BO0->hasNoUnsignedWrap();
bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW;
APInt RemYZ = IsSRem ? Y->srem(*Z) : Y->urem(*Z);
// (rem (mul nuw/nsw X, Y), (mul X, Z))
// if (rem Y, Z) == 0
// -> 0
if (RemYZ.isZero() && BO0NoWrap)
return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
bool BO1HasNSW = BO1->hasNoSignedWrap();
bool BO1HasNUW = BO1->hasNoUnsignedWrap();
bool BO1NoWrap = IsSRem ? BO1HasNSW : BO1HasNUW;
// (rem (mul X, Y), (mul nuw/nsw X, Z))
// if (rem Y, Z) == Y
// -> (mul nuw/nsw X, Y)
if (RemYZ == *Y && BO1NoWrap) {
BinaryOperator *BO =
BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), *Y));
// Copy any overflow flags from Op0.
BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
return BO;
}
// (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
// if Y >= Z
// -> (mul {nuw} nsw X, (rem Y, Z))
if (Y->uge(*Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
BinaryOperator *BO =
BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ));
BO->setHasNoSignedWrap();
BO->setHasNoUnsignedWrap(BO0HasNUW);
return BO;
}
return nullptr;
}
/// This function implements the transforms common to both integer remainder
/// instructions (urem and srem). It is called by the visitors to those integer
/// remainder instructions.
@ -1750,6 +1807,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
}
}
if (Instruction *R = simplifyIRemMulShl(I, *this))
return R;
return nullptr;
}

View File

@ -31,10 +31,7 @@ define i8 @urem_1_shl(i8 %X, i8 %Y) {
define <vscale x 16 x i8> @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16 x i8> %X) {
; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(
; CHECK-NEXT: [[BO0:%.*]] = mul nuw <vscale x 16 x i8> [[X:%.*]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 15, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
; CHECK-NEXT: [[BO1:%.*]] = mul <vscale x 16 x i8> [[X]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 5, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
; CHECK-NEXT: [[R:%.*]] = urem <vscale x 16 x i8> [[BO0]], [[BO1]]
; CHECK-NEXT: ret <vscale x 16 x i8> [[R]]
; CHECK-NEXT: ret <vscale x 16 x i8> zeroinitializer
;
%BO0 = mul nuw <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 15, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
%BO1 = mul <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 5, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
@ -44,10 +41,7 @@ define <vscale x 16 x i8> @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16
define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) {
; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0(
; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 15
; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 5
; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
; CHECK-NEXT: ret i8 [[R]]
; CHECK-NEXT: ret i8 0
;
%BO0 = mul nuw i8 %X, 15
%BO1 = mul i8 %X, 5
@ -70,9 +64,7 @@ define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0_fail_missing_flag(i8 %X) {
define i8 @urem_XY_XZ_with_CY_lt_CZ(i8 %X) {
; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ(
; CHECK-NEXT: [[BO0:%.*]] = mul i8 [[X:%.*]], 3
; CHECK-NEXT: [[BO1:%.*]] = mul nuw i8 [[X]], 12
; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3
; CHECK-NEXT: ret i8 [[R]]
;
%BO0 = mul i8 %X, 3
@ -122,9 +114,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ_fail_missing_flag(i8 %X) {
define i8 @urem_XY_XZ_with_CY_gt_CZ(i8 %X) {
; CHECK-LABEL: @urem_XY_XZ_with_CY_gt_CZ(
; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 21
; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 6
; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 3
; CHECK-NEXT: ret i8 [[R]]
;
%BO0 = mul nuw i8 %X, 21
@ -242,10 +232,7 @@ define i8 @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_fail_missing_flags2(i8 %X, i8 %Y,
;; Signed Verions
define <vscale x 16 x i8> @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16 x i8> %X) {
; CHECK-LABEL: @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(
; CHECK-NEXT: [[BO0:%.*]] = mul nsw <vscale x 16 x i8> [[X:%.*]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 15, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
; CHECK-NEXT: [[BO1:%.*]] = mul <vscale x 16 x i8> [[X]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 5, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
; CHECK-NEXT: [[R:%.*]] = srem <vscale x 16 x i8> [[BO0]], [[BO1]]
; CHECK-NEXT: ret <vscale x 16 x i8> [[R]]
; CHECK-NEXT: ret <vscale x 16 x i8> zeroinitializer
;
%BO0 = mul nsw <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 15, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
%BO1 = mul <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 5, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
@ -255,10 +242,7 @@ define <vscale x 16 x i8> @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16
define i8 @srem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) {
; CHECK-LABEL: @srem_XY_XZ_with_CY_rem_CZ_eq_0(
; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 9
; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 3
; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
; CHECK-NEXT: ret i8 [[R]]
; CHECK-NEXT: ret i8 0
;
%BO0 = mul nsw i8 %X, 9
%BO1 = mul i8 %X, 3
@ -294,9 +278,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ(<2 x i8> %X) {
define i8 @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(i8 %X) {
; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(
; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 5
; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 15
; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 5
; CHECK-NEXT: ret i8 [[R]]
;
%BO0 = mul nuw i8 %X, 5
@ -346,9 +328,7 @@ define i8 @srem_XY_XZ_with_CY_gt_CZ(i8 %X) {
define i8 @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out(i8 %X) {
; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out(
; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[X:%.*]], 10
; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6
; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
; CHECK-NEXT: [[R:%.*]] = shl nuw nsw i8 [[X:%.*]], 2
; CHECK-NEXT: ret i8 [[R]]
;
%BO0 = mul nsw nuw i8 %X, 10