Don't fail if unable to promote loops during unrolling

When the unroll factor is 1, "unrolling" should fail only if the trip count is also known to be 1 and the loop cannot be promoted.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D115365
This commit is contained in:
Tyler Augustine 2022-01-10 21:42:03 +00:00 committed by Mehdi Amini
parent 2154dbaa59
commit 87a9be2a74
2 changed files with 37 additions and 8 deletions

View File

@ -1182,8 +1182,14 @@ LogicalResult mlir::loopUnrollByFactor(
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "unroll factor should be positive");
if (unrollFactor == 1)
return promoteIfSingleIteration(forOp);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (unrollFactor == 1) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() == 1 &&
failed(promoteIfSingleIteration(forOp)))
return failure();
return success();
}
// Nothing in the loop body other than the terminator.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
@ -1191,7 +1197,6 @@ LogicalResult mlir::loopUnrollByFactor(
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO: option to specify cleanup loop unrolling.
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
return failure();
@ -1237,8 +1242,6 @@ LogicalResult mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
if (unrollFactor == 1)
return promoteIfSingleIteration(forOp);
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
@ -1264,6 +1267,13 @@ LogicalResult mlir::loopUnrollByFactor(
assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
"expected positive loop bounds and step");
int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
if (unrollFactor == 1) {
if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
return failure();
return success();
}
int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
assert(upperBoundUnrolledCst <= ubCst);
@ -1403,14 +1413,19 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
uint64_t unrollJamFactor) {
assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
if (unrollJamFactor == 1)
return promoteIfSingleIteration(forOp);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (unrollJamFactor == 1) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() == 1 &&
failed(promoteIfSingleIteration(forOp)))
return failure();
return success();
}
// Nothing in the loop body other than the terminator.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
return success();
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
// If the trip count is lower than the unroll jam factor, no unroll jam.
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor) {

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
// CHECK-LABEL: scf_loop_unroll_single
func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@ -42,3 +43,16 @@ func @scf_loop_unroll_double_symbolic_ub(%arg0 : f32, %arg1 : f32, %n : index) -
// CHECK: }
// CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
}
// UNROLL-BY-1-LABEL: scf_loop_unroll_factor_1_promote
// Single-iteration loop: constant bounds [0, 1) with step 1. For the
// unroll-factor=1 invocation at the top of this file, the expectations below
// require the constant 0 to be followed directly by the body op, which uses
// that constant in place of the induction variable.
func @scf_loop_unroll_factor_1_promote() -> () {
%step = arith.constant 1 : index
%lo = arith.constant 0 : index
%hi = arith.constant 1 : index
scf.for %i = %lo to %hi step %step {
%x = "test.foo"(%i) : (index) -> i32
}
return
// UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
}