[mlir][TilingInterface] Modify TilingInterface methods to better return the state of the transformed IR.

Currently the `getTiledImplementation` and `generateResultTileValue`
return just `SmallVector<Operation *>` and `FailureOr<Value>`.

- For `getTiledImplementation` returning empty implies tiling wasnt
  done. There is also an implicit assumption that the tiled operation
  results correspond to the tiled values of the result of the original
  operation. This cannot handle cases where the tiled implementation
  might use multiple operations to compute the tiled value for the
  results of the untiled operation. Sometimes, the tiled operation
  might not directly give the tiled values, and might require casts,
  etc to get a replacement.
- For `generateResultTileValue`, it is assumed that the op defining
  the returned `Value` is the operation that represents the tiled
  computation. Again presence of casts, etc violate this.

Instead make these methods return
```
struct TilingResult {
  SmallVector<Operation *> tiledOps;
  SmallVector<Value> tiledValues;
};
```

The `tiledOps` represent the operations generated that are relevant
for subsequent transformations. The `tiledValues` represent the tiled
values for the results of the original operation. This better
transmits the state of the transformed IR.

As a consequence the following methods also return `FailureOr<TilingResult>`
- `tensor::replaceExtractSliceWithTiledProducer`
- `tensor::bubbleUpPadSlice`

Differential Revision: https://reviews.llvm.org/D145133
This commit is contained in:
Mahesh Ravishankar 2023-03-01 16:33:14 -08:00
parent a586c55100
commit 809e3d8c98
12 changed files with 164 additions and 122 deletions

View File

@ -16,6 +16,9 @@
#include "mlir/IR/Dialect.h"
namespace mlir {
struct TilingResult;
namespace tensor {
class PadOp;
@ -39,10 +42,10 @@ class PadOp;
/// to guard against the case that we might take a zero-sized slice from the
/// original source. For such cases, we `tensor.generate` to generate the
/// full tensor.
Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool generateZeroSliceGuard = true);
FailureOr<TilingResult> bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool generateZeroSliceGuard = true);
/// Registers external models for Tiling interface for tensor ops.
/// Currently, it registers:

View File

@ -13,6 +13,9 @@
#include "mlir/IR/PatternMatch.h"
namespace mlir {
struct TilingResult;
namespace tensor {
/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
@ -26,7 +29,7 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
/// provide a mechanism to control where the application happens. With use of
/// transform dialect that control is done within the transform dialect. Other
/// use cases can inherit from this pattern and add necessary controls.
FailureOr<Value> replaceExtractSliceWithTiledProducer(
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice

View File

@ -21,6 +21,20 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
};
} // namespace mlir
/// Include the ODS generated interface header files.
#include "mlir/Interfaces/TilingInterface.h.inc"

View File

@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
/*retType=*/"SmallVector<Operation *>",
/*retType=*/"FailureOr<TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
@ -119,7 +119,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
/*retType=*/"FailureOr<Value>",
/*retType=*/"FailureOr<TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,

View File

@ -431,16 +431,15 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
static SmallVector<Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp, Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
<< "producer is not a TileableInterface: " << *producerOp;
return nullptr;
return {};
}
// Search the producer slices accessed within the containing operation.
@ -455,7 +454,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
if (it == tileableProducer->getUsers().end()) {
diag.attachNote(tileableProducer->getLoc())
<< "could not find fusion opportunity for: " << *tileableProducer;
return nullptr;
return {};
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
@ -468,27 +467,29 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
FailureOr<TilingResult> tileAndFuseResult =
tileableProducer.generateResultTileValue(rewriter, resultNumber,
sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tileAndFuseResult)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
return {};
}
for (auto tiledOp : tileAndFuseResult->tiledOps) {
LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
}
LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
sliceOpToTile->getResult(0)
.getType()
.cast<RankedTensorType>()
.getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
return fusedOp;
return tileAndFuseResult->tiledOps;
}
/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
@ -497,7 +498,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
/// right before its "extract" use. The tiled op is fused under the
/// `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
@ -506,7 +508,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
<< "producer is not a TileableInterface: " << *producerOp;
return nullptr;
return {};
}
// Search the first use by a "scf::ForallOp" user.
@ -520,7 +522,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
if (!forallOp || forallOp != containingOp) {
diag.attachNote(tileableProducer->getLoc())
<< "could not find a use by the containing op: " << *tileableProducer;
return nullptr;
return {};
}
// Search the producer slices accessed within the containing
@ -542,7 +544,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
if (itBBArgUsers == bbArg.getUsers().end()) {
diag.attachNote(containingOp->getLoc())
<< "could not find fusion opportunity for bbArg: " << bbArg;
return nullptr;
return {};
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
@ -562,7 +564,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
destinationTensors))) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to get destination tensors for: " << *tileableProducer;
return nullptr;
return {};
}
IRMapping bvm;
@ -573,21 +575,19 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
// Tile the producer.
FailureOr<Value> tiledProducer =
FailureOr<TilingResult> tileAndFuseResult =
tileableProducerClone.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
if (failed(tileAndFuseResult)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
return {};
}
LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
sliceOpToTile->getResult(0)
.getType()
.cast<RankedTensorType>()
@ -601,7 +601,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
destinationTensors.front());
});
return fusedOp;
return tileAndFuseResult->tiledOps;
}
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
@ -714,21 +714,21 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
// cases, we can tile/clone once and reuse the value for each use.
// Futhermore, producers should then be traversed according to a
// topological sorting.
Operation *tiled =
SmallVector<Operation *> tiledOps =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
if (!tiledOps.empty()) {
LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
fusedOps.push_back(tiled);
fusedOps.append(tiledOps);
continue;
}
Operation *tiledContainingOpOperand =
SmallVector<Operation *> tiledContainingOpOperand =
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
if (!tiledContainingOpOperand.empty()) {
LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
<< *containingOp);
fusedOps.push_back(tiledContainingOpOperand);
fusedOps.append(tiledContainingOpOperand);
continue;
}

View File

@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
offsetsCopy[dimension] = offset;
// Create the part as it it were a single tile.
SmallVector<Operation *> tiled =
FailureOr<TilingResult> tilingResult =
op.getTiledImplementation(b, offsetsCopy, sizesCopy);
assert(tiled.size() == 1 && "expected a single result from tiling");
auto part = cast<TilingInterface>(tiled.front());
// Insert the results back and populate the `results` list.
for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
resultOffsets, resultSizes)))
return nullptr;
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
b.getIndexAttr(1));
Value inserted = b.create<tensor::InsertSliceOp>(
loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
loc, result, resultOperands[index], resultOffsets, resultSizes,
resultStrides);
results.push_back(inserted);
}
return part;
// TODO: this part can be generalized maybe to not expect a single op.
assert(tilingResult->tiledOps.size() == 1 &&
"expected split part to return a single tiled operation");
return cast<TilingInterface>(tilingResult->tiledOps[0]);
}
std::pair<TilingInterface, TilingInterface>

View File

@ -388,12 +388,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
}
// 4. Tile the cloned op and delete the clone.
SmallVector<Operation *> tiledOps =
FailureOr<TilingResult> tilingResult =
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
tiledSizes);
b.eraseOp(clonedOp);
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
tiledOp = tiledOps.front();
assert(tilingResult->tiledOps.size() == 1 &&
"expected a single produced tiled op");
tiledOp = tilingResult->tiledOps.front();
}
// 5. Parallel insert back into the result tensor.
@ -729,12 +730,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 5. Tile the cloned op and delete the clone.
if (tileSizes.empty()) {
SmallVector<Operation *> tiledOps =
FailureOr<TilingResult> tilingResult =
cast<TilingInterface>(clonedOp).getTiledImplementation(
b, tiledOffsets, tiledSizes);
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
tiledOp = tiledOps.front();
tilingResults = tiledOp->getResults();
assert(tilingResult->tiledOps.size() == 1 &&
"expected a single produced tiled op");
tiledOp = tilingResult->tiledOps.front();
tilingResults = tilingResult->tiledValues;
} else {
LinalgTilingOptions options;
FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(

View File

@ -111,7 +111,7 @@ struct LinalgOpTilingInterface
}
// Instantiate the tiled implementation of the operation.
SmallVector<Operation *>
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
@ -129,7 +129,7 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
return {tiledOp};
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
// Return the details of the output tile generated by the tiled
@ -160,10 +160,10 @@ struct LinalgOpTilingInterface
return success();
}
FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
@ -197,12 +197,15 @@ struct LinalgOpTilingInterface
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}
SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
b, iterationTileOffsets, iterationTileSizes);
if (tiledOp.size() != 1)
FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
iterationTileSizes);
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
return tiledOp[0]->getResult(resultNumber);
return TilingResult{
tilingResult->tiledOps,
SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
}
LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,

View File

@ -952,12 +952,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
return failure();
}
Operation *tiledPadOp =
FailureOr<TilingResult> tilingResult =
tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), zeroSliceGuard);
if (failed(tilingResult))
return failure();
// All shapes are static and the data source is actually used. Rewrite into
// pad(extract_slice(x)).
rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
return success();
}

View File

@ -251,18 +251,20 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
/// a destination passing style op.
static SmallVector<Value>
yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
Operation *tiledOp,
TilingResult tilingResult,
ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
MutableArrayRef<scf::ForOp> loops) {
SmallVector<Value> replacements =
yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
tileOffsetsList, tileSizesList, loops);
if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
auto innerMostLoop = loops.back();
SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
innerMostLoop.getRegionIterArgs());
for (auto tiledOp : tilingResult.tiledOps) {
if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
auto innerMostLoop = loops.back();
SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
innerMostLoop.getRegionIterArgs());
}
}
return replacements;
}
@ -345,9 +347,9 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
if (!tilingResult.loops.empty())
rewriter.setInsertionPoint(
tilingResult.loops.back().getBody()->getTerminator());
SmallVector<Operation *> tiledImplementation =
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
tilingResult.tiledOps.append(tiledImplementation);
tilingResult.tiledOps.append(tiledImplementation->tiledOps);
if (op->getNumResults() == 0) {
// nothing more to do.
return tilingResult;
@ -356,9 +358,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
if (tilingResult.loops.empty()) {
tilingResult.replacements = llvm::to_vector(
llvm::map_range(tiledImplementation[0]->getResults(),
[](OpResult result) -> Value { return result; }));
tilingResult.replacements = tiledImplementation->tiledValues;
return tilingResult;
}
@ -384,7 +384,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
return rewriter.notifyMatchFailure(op, "failed to get destinations");
tilingResult.replacements = yieldTiledValues(
rewriter, destinationTensors, tilingResult.tiledOps.back(),
rewriter, destinationTensors, tiledImplementation.value(),
resultOffsetsList, resultSizesList, tilingResult.loops);
LLVM_DEBUG({
@ -523,12 +523,13 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// 2. Generate the tiled implementation of the producer of the source
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(candidateSliceOp);
FailureOr<Value> fusedProducerValue =
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
fusableProducer);
if (failed(fusedProducerValue))
if (failed(tileAndFuseResult))
return std::nullopt;
rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value());
rewriter.replaceAllUsesWith(candidateSliceOp,
tileAndFuseResult->tiledValues[0]);
// 3. If the slice is for a destination operand, for example,
//
@ -592,8 +593,10 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
outerMostLoop.setIterArg(iterArgNumber.value(),
dstOp.getTiedOpOperand(fusableProducer)->get());
}
if (auto dstOp = fusedProducerValue.value()
.getDefiningOp<DestinationStyleOpInterface>()) {
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
if (!dstOp)
continue;
scf::ForOp innerMostLoop = loops.back();
updateDestinationOperandsForTiledOp(
rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
@ -601,7 +604,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
}
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
fusedProducerValue.value()};
tileAndFuseResult->tiledValues[0]};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.

View File

@ -46,15 +46,15 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
return loopRanges;
}
SmallVector<Operation *>
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
Operation *result =
FailureOr<TilingResult> result =
tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
if (!result)
return {};
return {result};
if (failed(result))
return failure();
return result.value();
}
LogicalResult
@ -117,7 +117,7 @@ struct PackOpTiling
return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
}
SmallVector<Operation *>
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
@ -192,7 +192,8 @@ struct PackOpTiling
Operation *tiledPackOp = b.create<PackOp>(
loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
return {tiledPackOp};
return TilingResult{{tiledPackOp},
SmallVector<Value>(tiledPackOp->getResults())};
}
LogicalResult
@ -353,7 +354,7 @@ struct UnPackOpTiling
/// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
/// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
/// can get the actual result.
SmallVector<Operation *>
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
@ -412,12 +413,13 @@ struct UnPackOpTiling
loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
if (isPerfectTilingCase)
return {tiledUnpackOp};
return TilingResult{{tiledUnpackOp},
SmallVector<Value>(tiledUnpackOp->getResults())};
Operation *extractSlice =
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
return {tiledUnpackOp, extractSlice};
return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
}
LogicalResult
@ -431,26 +433,29 @@ struct UnPackOpTiling
return success();
}
FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
return getTiledImplementation(op, b, offsets, sizes)
.back()
->getResult(resultNumber);
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
FailureOr<TilingResult> tilingResult =
getTiledImplementation(op, b, offsets, sizes);
if (failed(tilingResult))
return failure();
return tilingResult.value();
}
};
} // namespace
Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool generateZeroSliceGuard) {
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
tensor::PadOp padOp,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool generateZeroSliceGuard) {
// Only constant padding value supported.
Value padValue = padOp.getConstantPaddingValue();
if (!padValue)
return nullptr;
return failure();
// Helper variables and functions for various arithmetic operations. These
// are used extensively for computing new offset/length and padding values.
@ -584,10 +589,9 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
RankedTensorType::get(shape, padOp.getResultType().getElementType());
// Insert cast to ensure that types match. (May be folded away.)
auto castResult = [&](Operation *op) -> Operation * {
Value val = op->getResult(0);
auto castResult = [&](Value val) -> Value {
if (resultType == val.getType())
return op;
return val;
return b.create<tensor::CastOp>(loc, resultType, val);
};
@ -601,7 +605,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
[&](OpBuilder &builder, Location gLoc, ValueRange indices) {
builder.create<tensor::YieldOp>(gLoc, padValue);
});
return castResult(generateOp);
return generateOp;
};
// Emit a SliceOp and a PadOp. Should not be used in cases where
@ -617,30 +621,38 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
// Cast result and return.
return castResult(newPadOp);
return newPadOp;
};
// Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
// the original data source x is not used.
if (hasZeroLen)
return createGenerateOp();
if (hasZeroLen) {
Operation *generateOp = createGenerateOp();
return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
}
// If there are dynamic dimensions: Generate an scf.if check to avoid
// creating SliceOps with result dimensions of size 0 at runtime.
if (generateZeroSliceGuard && dynHasZeroLenCond) {
Operation *thenOp;
Operation *elseOp;
auto result = b.create<scf::IfOp>(
loc, dynHasZeroLenCond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
thenOp = createGenerateOp();
b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
},
/*elseBuilder=*/
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
elseOp = createPadOfExtractSlice();
b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
});
return result;
return TilingResult{{result}, SmallVector<Value>(result->getResults())};
}
return createPadOfExtractSlice();
Operation *newPadOp = createPadOfExtractSlice();
return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}
void mlir::tensor::registerTilingInterfaceExternalModels(

View File

@ -20,7 +20,7 @@
using namespace mlir;
FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
if (!producerOp)
@ -32,7 +32,7 @@ FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
}))
return failure();
FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes());
if (failed(tiledResult))