[mlir][Vector] Add folding for masked reductions and vector.mask
This patch adds support for folding trivial masked reductions and multi-reductions (e.g., multi-reductions with only parallel dims, reductions of a single element, etc.). To support those foldings in a composable way we also add support for folding different flavors of empty vector.mask operations. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D144414
This commit is contained in:
parent
5f2618fe16
commit
51f235c444
|
@ -1172,7 +1172,7 @@ def Vector_ExtractStridedSliceOp :
|
|||
static StringRef getSizesAttrStrName() { return "sizes"; }
|
||||
static StringRef getStridesAttrStrName() { return "strides"; }
|
||||
VectorType getSourceVectorType() {
|
||||
return getVector().getType().cast<VectorType>();
|
||||
return getVector().getType().cast<VectorType>();
|
||||
}
|
||||
void getOffsets(SmallVectorImpl<int64_t> &results);
|
||||
bool hasNonUnitStrides() {
|
||||
|
@ -2382,9 +2382,11 @@ def Vector_MaskOp : Vector_Op<"mask", [
|
|||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Block *getMaskBlock() { return &getMaskRegion().front(); }
|
||||
static void ensureTerminator(Region ®ion, Builder &builder, Location loc);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
|
|
@ -361,34 +361,50 @@ struct ElideUnitDimsInMultiDimReduction
|
|||
|
||||
LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Masked reductions can't be folded until we can propagate the mask to the
|
||||
// resulting operation.
|
||||
auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
|
||||
if (maskableOp.isMasked())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
|
||||
for (const auto &dim : enumerate(shape)) {
|
||||
if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Vector mask setup.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Operation *rootOp;
|
||||
Value mask;
|
||||
if (reductionOp.isMasked()) {
|
||||
rewriter.setInsertionPoint(reductionOp.getMaskingOp());
|
||||
rootOp = reductionOp.getMaskingOp();
|
||||
mask = reductionOp.getMaskingOp().getMask();
|
||||
} else {
|
||||
rootOp = reductionOp;
|
||||
}
|
||||
|
||||
Location loc = reductionOp.getLoc();
|
||||
Value acc = reductionOp.getAcc();
|
||||
Value cast;
|
||||
if (reductionOp.getDestType().isa<VectorType>()) {
|
||||
if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
|
||||
if (mask) {
|
||||
VectorType newMaskType =
|
||||
VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
|
||||
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
|
||||
}
|
||||
cast = rewriter.create<vector::ShapeCastOp>(
|
||||
loc, reductionOp.getDestType(), reductionOp.getSource());
|
||||
} else {
|
||||
// This means we are reducing all the dimensions, and all reduction
|
||||
// dimensions are of size 1. So a simple extraction would do.
|
||||
auto zeroAttr =
|
||||
rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
|
||||
if (mask)
|
||||
mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
|
||||
mask, zeroAttr);
|
||||
cast = rewriter.create<vector::ExtractOp>(
|
||||
loc, reductionOp.getDestType(), reductionOp.getSource(),
|
||||
rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
|
||||
loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
|
||||
}
|
||||
|
||||
Value result = vector::makeArithReduction(rewriter, loc,
|
||||
reductionOp.getKind(), acc, cast);
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
Value result = vector::makeArithReduction(
|
||||
rewriter, loc, reductionOp.getKind(), acc, cast, mask);
|
||||
rewriter.replaceOp(rootOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -524,11 +540,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(ReductionOp reductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Masked reductions can't be folded until we can propagate the mask to the
|
||||
// resulting operation.
|
||||
auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
|
||||
if (maskableOp.isMasked())
|
||||
return failure();
|
||||
// Vector mask setup.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
auto maskableOp =
|
||||
cast<vector::MaskableOpInterface>(reductionOp.getOperation());
|
||||
Operation *rootOp;
|
||||
Value mask;
|
||||
if (maskableOp.isMasked()) {
|
||||
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
|
||||
rootOp = maskableOp.getMaskingOp();
|
||||
mask = maskableOp.getMaskingOp().getMask();
|
||||
} else {
|
||||
rootOp = reductionOp;
|
||||
}
|
||||
|
||||
auto vectorType = reductionOp.getSourceVectorType();
|
||||
if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
|
||||
|
@ -537,8 +561,14 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
|
|||
Location loc = reductionOp.getLoc();
|
||||
Value result;
|
||||
if (vectorType.getRank() == 0) {
|
||||
if (mask)
|
||||
mask = rewriter.create<ExtractElementOp>(loc, mask);
|
||||
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
|
||||
} else {
|
||||
if (mask) {
|
||||
mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
|
||||
rewriter.getI64ArrayAttr(0));
|
||||
}
|
||||
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
|
||||
reductionOp.getVector(),
|
||||
rewriter.getI64ArrayAttr(0));
|
||||
|
@ -546,9 +576,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
|
|||
|
||||
if (Value acc = reductionOp.getAcc())
|
||||
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
|
||||
result, acc);
|
||||
result, acc, mask);
|
||||
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
rewriter.replaceOp(rootOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -5465,7 +5495,7 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
|
|||
// Print single masked operation and skip terminator.
|
||||
p << " { ";
|
||||
Block *singleBlock = &getMaskRegion().getBlocks().front();
|
||||
if (singleBlock && singleBlock->getOperations().size() > 1)
|
||||
if (singleBlock && singleBlock->getOperations().size() >= 1)
|
||||
p.printCustomOrGenericOp(&singleBlock->front());
|
||||
p << " }";
|
||||
|
||||
|
@ -5481,33 +5511,49 @@ void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
|
|||
MaskOp>::ensureTerminator(region, builder, loc);
|
||||
// Keep the default yield terminator if the number of masked operations is not
|
||||
// the expected. This case will trigger a verification failure.
|
||||
if (region.front().getOperations().size() != 2)
|
||||
Block &block = region.front();
|
||||
if (block.getOperations().size() != 2)
|
||||
return;
|
||||
|
||||
// Replace default yield terminator with a new one that returns the results
|
||||
// from the masked operation.
|
||||
OpBuilder opBuilder(builder.getContext());
|
||||
Operation *maskedOp = ®ion.front().front();
|
||||
Operation *oldYieldOp = ®ion.front().back();
|
||||
Operation *maskedOp = &block.front();
|
||||
Operation *oldYieldOp = &block.back();
|
||||
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
|
||||
|
||||
// Empty vector.mask op.
|
||||
if (maskedOp == oldYieldOp)
|
||||
return;
|
||||
|
||||
opBuilder.setInsertionPoint(oldYieldOp);
|
||||
opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
|
||||
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
|
||||
oldYieldOp->dropAllReferences();
|
||||
oldYieldOp->erase();
|
||||
return;
|
||||
}
|
||||
|
||||
LogicalResult MaskOp::verify() {
|
||||
// Structural checks.
|
||||
Block &block = getMaskRegion().getBlocks().front();
|
||||
if (block.getOperations().size() < 2)
|
||||
return emitOpError("expects an operation to mask");
|
||||
if (block.getOperations().size() < 1)
|
||||
return emitOpError("expects a terminator within the mask region");
|
||||
if (block.getOperations().size() > 2)
|
||||
return emitOpError("expects only one operation to mask");
|
||||
|
||||
// Terminator checks.
|
||||
auto terminator = dyn_cast<vector::YieldOp>(block.back());
|
||||
if (!terminator)
|
||||
return emitOpError("expects a terminator within the mask region");
|
||||
|
||||
if (terminator->getNumOperands() != getNumResults())
|
||||
return emitOpError(
|
||||
"expects number of results to match mask region yielded values");
|
||||
|
||||
auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
|
||||
// Empty vector.mask. Nothing else to check.
|
||||
if (!maskableOp)
|
||||
return emitOpError("expects a maskable operation");
|
||||
return success();
|
||||
|
||||
// Result checks.
|
||||
if (maskableOp->getNumResults() != getNumResults())
|
||||
|
@ -5545,10 +5591,47 @@ LogicalResult MaskOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Elides empty vector.mask operations with or without return values. Propagates
|
||||
// the yielded values by the vector.yield terminator, if any, or erases the op,
|
||||
// otherwise.
|
||||
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(MaskOp maskOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
|
||||
if (maskingOp.getMaskableOp())
|
||||
return failure();
|
||||
|
||||
Block *block = maskOp.getMaskBlock();
|
||||
if (block->getOperations().size() > 1)
|
||||
return failure();
|
||||
|
||||
auto terminator = cast<vector::YieldOp>(block->front());
|
||||
if (terminator.getNumOperands() == 0)
|
||||
rewriter.eraseOp(maskOp);
|
||||
else
|
||||
rewriter.replaceOp(maskOp, terminator.getOperands());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ElideEmptyMaskOp>(context);
|
||||
}
|
||||
|
||||
// MaskingOpInterface definitions.
|
||||
|
||||
/// Returns the operation masked by this 'vector.mask'.
|
||||
Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
|
||||
Operation *MaskOp::getMaskableOp() {
|
||||
Block *block = getMaskBlock();
|
||||
if (block->getOperations().size() < 2)
|
||||
return nullptr;
|
||||
|
||||
return &block->front();
|
||||
}
|
||||
|
||||
/// Returns true if 'vector.mask' has a passthru value.
|
||||
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
|
||||
|
|
|
@ -1372,6 +1372,16 @@ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: ve
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @masked_vector_multi_reduction_single_parallel(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %{{.*}}: vector<2xf32>,
|
||||
func.func @masked_vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>, %mask: vector<2xi1>) -> vector<2xf32> {
|
||||
%0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32> } : vector<2xi1> -> vector<2xf32>
|
||||
// CHECK: return %[[VAL_0]] : vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
|
||||
// CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
|
||||
func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
|
||||
|
@ -1385,14 +1395,17 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
|
|||
|
||||
// -----
|
||||
|
||||
// Masked reduction can't be folded.
|
||||
|
||||
// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: vector<5x1x4x1x20xf32>, %[[VAL_1:.*]]: vector<5x4x20xf32>,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: vector<5x1x4x1x20xi1>)
|
||||
func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>,
|
||||
%acc: vector<5x4x20xf32>,
|
||||
%mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
|
||||
// CHECK: vector.mask %{{.*}} { vector.multi_reduction <mul>
|
||||
%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
|
||||
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<5x1x4x1x20xi1> to vector<5x4x20xi1>
|
||||
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_0]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : vector<5x4x20xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : vector<5x4x20xi1>, vector<5x4x20xf32>
|
||||
%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
|
||||
vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
|
||||
return %0 : vector<5x4x20xf32>
|
||||
}
|
||||
|
@ -1424,6 +1437,20 @@ func.func @vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions_single_elem(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x1xf32>, %[[VAL_1:.*]]: f32,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: vector<1x1x1xi1>)
|
||||
func.func @masked_vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x1x1xf32>, %acc: f32, %mask: vector<1x1x1xi1>) -> f32 {
|
||||
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0, 0, 0] : vector<1x1x1xi1>
|
||||
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0, 0, 0] : vector<1x1x1xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : f32
|
||||
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : f32
|
||||
%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [0,1,2] : vector<1x1x1xf32> to f32 } : vector<1x1x1xi1> -> f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @insert_strided_slice_full_range
|
||||
// CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
|
||||
func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
|
||||
|
@ -1937,6 +1964,17 @@ func.func @reduce_one_element_vector_extract(%a : vector<1xf32>) -> f32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @masked_reduce_one_element_vector_extract
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: vector<1xi1>)
|
||||
func.func @masked_reduce_one_element_vector_extract(%a : vector<1xf32>, %mask : vector<1xi1>) -> f32 {
|
||||
// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
|
||||
%s = vector.mask %mask { vector.reduction <add>, %a : vector<1xf32> into f32 }
|
||||
: vector<1xi1> -> f32
|
||||
return %s : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce_one_element_vector_addf
|
||||
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
|
||||
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
|
||||
|
@ -1950,10 +1988,15 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
|
||||
// CHECK: vector.mask %{{.*}} { vector.reduction <add>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
|
||||
func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>,
|
||||
%b: f32,
|
||||
%mask: vector<1xi1>) -> f32 {
|
||||
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0] : vector<1xi1>
|
||||
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_1]] : f32
|
||||
%s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
|
||||
: vector<1xi1> -> f32
|
||||
return %s : f32
|
||||
|
@ -2167,3 +2210,25 @@ func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
|
|||
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @empty_vector_mask
|
||||
func.func @empty_vector_mask(%mask : vector<8xi1>) {
|
||||
// CHECK-NOT: vector.mask
|
||||
vector.mask %mask { } : vector<8xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @empty_vector_mask_with_return
|
||||
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
|
||||
func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
|
||||
// CHECK-NOT: vector.mask
|
||||
// CHECK: return %[[IN]] : vector<8xf32>
|
||||
%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1604,13 +1604,6 @@ func.func @warp_mismatch_rank(%laneid: index) {
|
|||
|
||||
// -----
|
||||
|
||||
func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 {
|
||||
// expected-error@+1 {{'vector.mask' op expects an operation to mask}}
|
||||
vector.mask %m0 { } : vector<16xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @vector_mask_multiple_ops(%t0: tensor<?xf32>, %t1: tensor<?xf32>, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) {
|
||||
%ft0 = arith.constant 0.0 : f32
|
||||
// expected-error@+1 {{'vector.mask' op expects only one operation to mask}}
|
||||
|
|
|
@ -860,6 +860,27 @@ func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_mask_empty
|
||||
func.func @vector_mask_empty(%m0: vector<16xi1>) {
|
||||
// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
|
||||
vector.mask %m0 { } : vector<16xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_mask_empty_with_yield
|
||||
func.func @vector_mask_empty_with_yield(%m0: vector<16xi1>) {
|
||||
// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
|
||||
vector.mask %m0 { vector.yield } : vector<16xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_mask_empty_return
|
||||
func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK: vector.mask %{{.*}} { vector.yield {{.*}} : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
|
||||
%0 = vector.mask %m0 { vector.yield %arg0 : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_scalable_insert(
|
||||
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
|
||||
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
|
||||
|
|
Loading…
Reference in New Issue
Block a user