Changes to SCFFuseProducerOfSliceResult
to also return the operations created during fusion.
This is follow up to https://reviews.llvm.org/D145133 that allows propogating information about ops that are fused back to the caller. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D146254
This commit is contained in:
parent
091422adc1
commit
3af1c48c66
|
@ -96,6 +96,7 @@ struct SCFTileAndFuseOptions {
|
|||
struct SCFFuseProducerOfSliceResult {
|
||||
OpResult origProducer; // Original untiled producer.
|
||||
Value tiledAndFusedProducer; // Tile and fused producer value.
|
||||
SmallVector<Operation *> tiledOps;
|
||||
};
|
||||
std::optional<SCFFuseProducerOfSliceResult>
|
||||
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
|
||||
|
|
|
@ -604,7 +604,8 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
|
|||
}
|
||||
}
|
||||
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
|
||||
tileAndFuseResult->tiledValues[0]};
|
||||
tileAndFuseResult->tiledValues[0],
|
||||
tileAndFuseResult->tiledOps};
|
||||
}
|
||||
|
||||
/// Reconstruct the fused producer from within the tiled-and-fused code.
|
||||
|
@ -612,7 +613,8 @@ void mlir::scf::yieldReplacementForFusedProducer(
|
|||
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
|
||||
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
|
||||
MutableArrayRef<scf::ForOp> loops) {
|
||||
auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
|
||||
auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
|
||||
fusedProducerInfo;
|
||||
SmallVector<Value> initValues;
|
||||
FailureOr<Value> initValue = tensor::getOrCreateDestination(
|
||||
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
|
||||
|
@ -623,8 +625,11 @@ void mlir::scf::yieldReplacementForFusedProducer(
|
|||
yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
|
||||
resultOffsets, resultSizes, loops);
|
||||
}
|
||||
if (auto dstStyleProducer =
|
||||
fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
|
||||
for (auto tileAndFusedOp : tileAndFusedOps) {
|
||||
auto dstStyleProducer =
|
||||
dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
|
||||
if (!dstStyleProducer)
|
||||
continue;
|
||||
Value dstValue =
|
||||
dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
|
||||
->get();
|
||||
|
|
Loading…
Reference in New Issue
Block a user