[mlir][Vector] Add folding for masked reductions and vector.mask

This patch adds support for folding trivial masked reductions and
multi-reductions (e.g., multi-reductions with only parallel dims,
reductions of a single element, etc.). To support those foldings in
a composable way we also add support for folding different flavors of
empty vector.mask operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144414
This commit is contained in:
Diego Caballero 2023-02-22 06:37:38 +00:00
parent 5f2618fe16
commit 51f235c444
5 changed files with 205 additions and 41 deletions

View File

@ -1172,7 +1172,7 @@ def Vector_ExtractStridedSliceOp :
static StringRef getSizesAttrStrName() { return "sizes"; }
static StringRef getStridesAttrStrName() { return "strides"; }
VectorType getSourceVectorType() {
return getVector().getType().cast<VectorType>();
return getVector().getType().cast<VectorType>();
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
@ -2382,9 +2382,11 @@ def Vector_MaskOp : Vector_Op<"mask", [
];
let extraClassDeclaration = [{
Block *getMaskBlock() { return &getMaskRegion().front(); }
static void ensureTerminator(Region &region, Builder &builder, Location loc);
}];
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

View File

@ -361,34 +361,50 @@ struct ElideUnitDimsInMultiDimReduction
LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
// Masked reductions can't be folded until we can propagate the mask to the
// resulting operation.
auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
if (maskableOp.isMasked())
return failure();
ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
for (const auto &dim : enumerate(shape)) {
if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
return failure();
}
// Vector mask setup.
OpBuilder::InsertionGuard guard(rewriter);
Operation *rootOp;
Value mask;
if (reductionOp.isMasked()) {
rewriter.setInsertionPoint(reductionOp.getMaskingOp());
rootOp = reductionOp.getMaskingOp();
mask = reductionOp.getMaskingOp().getMask();
} else {
rootOp = reductionOp;
}
Location loc = reductionOp.getLoc();
Value acc = reductionOp.getAcc();
Value cast;
if (reductionOp.getDestType().isa<VectorType>()) {
if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
if (mask) {
VectorType newMaskType =
VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}
cast = rewriter.create<vector::ShapeCastOp>(
loc, reductionOp.getDestType(), reductionOp.getSource());
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
auto zeroAttr =
rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
if (mask)
mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
mask, zeroAttr);
cast = rewriter.create<vector::ExtractOp>(
loc, reductionOp.getDestType(), reductionOp.getSource(),
rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
}
Value result = vector::makeArithReduction(rewriter, loc,
reductionOp.getKind(), acc, cast);
rewriter.replaceOp(reductionOp, result);
Value result = vector::makeArithReduction(
rewriter, loc, reductionOp.getKind(), acc, cast, mask);
rewriter.replaceOp(rootOp, result);
return success();
}
};
@ -524,11 +540,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
LogicalResult matchAndRewrite(ReductionOp reductionOp,
PatternRewriter &rewriter) const override {
// Masked reductions can't be folded until we can propagate the mask to the
// resulting operation.
auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
if (maskableOp.isMasked())
return failure();
// Vector mask setup.
OpBuilder::InsertionGuard guard(rewriter);
auto maskableOp =
cast<vector::MaskableOpInterface>(reductionOp.getOperation());
Operation *rootOp;
Value mask;
if (maskableOp.isMasked()) {
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
rootOp = maskableOp.getMaskingOp();
mask = maskableOp.getMaskingOp().getMask();
} else {
rootOp = reductionOp;
}
auto vectorType = reductionOp.getSourceVectorType();
if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
@ -537,8 +561,14 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Location loc = reductionOp.getLoc();
Value result;
if (vectorType.getRank() == 0) {
if (mask)
mask = rewriter.create<ExtractElementOp>(loc, mask);
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
} else {
if (mask) {
mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
rewriter.getI64ArrayAttr(0));
}
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));
@ -546,9 +576,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
result, acc);
result, acc, mask);
rewriter.replaceOp(reductionOp, result);
rewriter.replaceOp(rootOp, result);
return success();
}
};
@ -5465,7 +5495,7 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
// Print single masked operation and skip terminator.
p << " { ";
Block *singleBlock = &getMaskRegion().getBlocks().front();
if (singleBlock && singleBlock->getOperations().size() > 1)
if (singleBlock && singleBlock->getOperations().size() >= 1)
p.printCustomOrGenericOp(&singleBlock->front());
p << " }";
@ -5481,33 +5511,49 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
MaskOp>::ensureTerminator(region, builder, loc);
// Keep the default yield terminator if the number of masked operations is not
// the expected. This case will trigger a verification failure.
if (region.front().getOperations().size() != 2)
Block &block = region.front();
if (block.getOperations().size() != 2)
return;
// Replace default yield terminator with a new one that returns the results
// from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &region.front().front();
Operation *oldYieldOp = &region.front().back();
Operation *maskedOp = &block.front();
Operation *oldYieldOp = &block.back();
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
// Empty vector.mask op.
if (maskedOp == oldYieldOp)
return;
opBuilder.setInsertionPoint(oldYieldOp);
opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
oldYieldOp->dropAllReferences();
oldYieldOp->erase();
return;
}
LogicalResult MaskOp::verify() {
// Structural checks.
Block &block = getMaskRegion().getBlocks().front();
if (block.getOperations().size() < 2)
return emitOpError("expects an operation to mask");
if (block.getOperations().size() < 1)
return emitOpError("expects a terminator within the mask region");
if (block.getOperations().size() > 2)
return emitOpError("expects only one operation to mask");
// Terminator checks.
auto terminator = dyn_cast<vector::YieldOp>(block.back());
if (!terminator)
return emitOpError("expects a terminator within the mask region");
if (terminator->getNumOperands() != getNumResults())
return emitOpError(
"expects number of results to match mask region yielded values");
auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
// Empty vector.mask. Nothing else to check.
if (!maskableOp)
return emitOpError("expects a maskable operation");
return success();
// Result checks.
if (maskableOp->getNumResults() != getNumResults())
@ -5545,10 +5591,47 @@ LogicalResult MaskOp::verify() {
return success();
}
// Elides empty vector.mask operations with or without return values. Propagates
// the yielded values by the vector.yield terminator, if any, or erases the op,
// otherwise.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const override {
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
if (maskingOp.getMaskableOp())
return failure();
Block *block = maskOp.getMaskBlock();
if (block->getOperations().size() > 1)
return failure();
auto terminator = cast<vector::YieldOp>(block->front());
if (terminator.getNumOperands() == 0)
rewriter.eraseOp(maskOp);
else
rewriter.replaceOp(maskOp, terminator.getOperands());
return success();
}
};
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ElideEmptyMaskOp>(context);
}
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
Operation *MaskOp::getMaskableOp() {
Block *block = getMaskBlock();
if (block->getOperations().size() < 2)
return nullptr;
return &block->front();
}
/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }

View File

@ -1372,6 +1372,16 @@ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: ve
// -----
// CHECK-LABEL: func @masked_vector_multi_reduction_single_parallel(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %{{.*}}: vector<2xf32>,
func.func @masked_vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>, %mask: vector<2xi1>) -> vector<2xf32> {
%0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32> } : vector<2xi1> -> vector<2xf32>
// CHECK: return %[[VAL_0]] : vector<2xf32>
return %0 : vector<2xf32>
}
// -----
// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
// CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
@ -1385,14 +1395,17 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
// -----
// Masked reduction can't be folded.
// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
// CHECK-SAME: %[[VAL_0:.*]]: vector<5x1x4x1x20xf32>, %[[VAL_1:.*]]: vector<5x4x20xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: vector<5x1x4x1x20xi1>)
func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>,
%acc: vector<5x4x20xf32>,
%mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
// CHECK: vector.mask %{{.*}} { vector.multi_reduction <mul>
%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<5x1x4x1x20xi1> to vector<5x4x20xi1>
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_0]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : vector<5x4x20xf32>
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : vector<5x4x20xi1>, vector<5x4x20xf32>
%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
return %0 : vector<5x4x20xf32>
}
@ -1424,6 +1437,20 @@ func.func @vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x
// -----
// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions_single_elem(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x1xf32>, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: vector<1x1x1xi1>)
func.func @masked_vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x1x1xf32>, %acc: f32, %mask: vector<1x1x1xi1>) -> f32 {
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0, 0, 0] : vector<1x1x1xi1>
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0, 0, 0] : vector<1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : f32
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : f32
%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [0,1,2] : vector<1x1x1xf32> to f32 } : vector<1x1x1xi1> -> f32
return %0 : f32
}
// -----
// CHECK-LABEL: func @insert_strided_slice_full_range
// CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
@ -1937,6 +1964,17 @@ func.func @reduce_one_element_vector_extract(%a : vector<1xf32>) -> f32 {
// -----
// CHECK-LABEL: func @masked_reduce_one_element_vector_extract
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: vector<1xi1>)
func.func @masked_reduce_one_element_vector_extract(%a : vector<1xf32>, %mask : vector<1xi1>) -> f32 {
// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
%s = vector.mask %mask { vector.reduction <add>, %a : vector<1xf32> into f32 }
: vector<1xi1> -> f32
return %s : f32
}
// -----
// CHECK-LABEL: func @reduce_one_element_vector_addf
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
@ -1950,10 +1988,15 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
// CHECK: vector.mask %{{.*}} { vector.reduction <add>
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>,
%b: f32,
%mask: vector<1xi1>) -> f32 {
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0] : vector<1xi1>
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_1]] : f32
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_1]] : f32
%s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
: vector<1xi1> -> f32
return %s : f32
@ -2167,3 +2210,25 @@ func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
return %0 : f32
}
// -----
// CHECK-LABEL: func @empty_vector_mask
func.func @empty_vector_mask(%mask : vector<8xi1>) {
// CHECK-NOT: vector.mask
vector.mask %mask { } : vector<8xi1>
return
}
// -----
// CHECK-LABEL: func @empty_vector_mask_with_return
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
// CHECK-NOT: vector.mask
// CHECK: return %[[IN]] : vector<8xf32>
%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32>
}

View File

@ -1604,13 +1604,6 @@ func.func @warp_mismatch_rank(%laneid: index) {
// -----
func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 {
// expected-error@+1 {{'vector.mask' op expects an operation to mask}}
vector.mask %m0 { } : vector<16xi1>
}
// -----
func.func @vector_mask_multiple_ops(%t0: tensor<?xf32>, %t1: tensor<?xf32>, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op expects only one operation to mask}}

View File

@ -860,6 +860,27 @@ func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %
return
}
// CHECK-LABEL: func @vector_mask_empty
func.func @vector_mask_empty(%m0: vector<16xi1>) {
// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
vector.mask %m0 { } : vector<16xi1>
return
}
// CHECK-LABEL: func @vector_mask_empty_with_yield
func.func @vector_mask_empty_with_yield(%m0: vector<16xi1>) {
// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
vector.mask %m0 { vector.yield } : vector<16xi1>
return
}
// CHECK-LABEL: func @vector_mask_empty_return
func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -> vector<16xf32> {
// CHECK: vector.mask %{{.*}} { vector.yield {{.*}} : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
%0 = vector.mask %m0 { vector.yield %arg0 : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @vector_scalable_insert(
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>