Changes to SCFFuseProducerOfSliceResult to also return the operations created during fusion.

This is follow up to https://reviews.llvm.org/D145133 that allows
propogating information about ops that are fused back to the caller.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D146254
This commit is contained in:
Mahesh Ravishankar 2023-03-20 18:58:39 +00:00
parent 091422adc1
commit 3af1c48c66
2 changed files with 10 additions and 4 deletions

View File

@ -96,6 +96,7 @@ struct SCFTileAndFuseOptions {
struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,

View File

@ -604,7 +604,8 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
}
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
tileAndFuseResult->tiledValues[0]};
tileAndFuseResult->tiledValues[0],
tileAndFuseResult->tiledOps};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
@ -612,7 +613,8 @@ void mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<scf::ForOp> loops) {
auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
fusedProducerInfo;
SmallVector<Value> initValues;
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
@ -623,8 +625,11 @@ void mlir::scf::yieldReplacementForFusedProducer(
yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
resultOffsets, resultSizes, loops);
}
if (auto dstStyleProducer =
fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
for (auto tileAndFusedOp : tileAndFusedOps) {
auto dstStyleProducer =
dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
if (!dstStyleProducer)
continue;
Value dstValue =
dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
->get();