[mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand

This commit is contained in:
Nicolas Vasilache 2023-03-21 01:04:04 -07:00
parent d5b2c8e56d
commit 9437bf418a
2 changed files with 29 additions and 22 deletions

View File

@ -83,8 +83,10 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
//===----------------------------------------------------------------------===//
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait]> {
let description = [{
Decomposes named complex operations, such as higher-dimensional
(depthwise) convolutions, into combinations of lower-dimensional equivalents
@ -932,9 +934,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
def RewriteInDestinationPassingStyleOp : Op<
Transform_Dialect, "structured.rewrite_in_destination_passing_style",
[MemoryEffectsOpInterface,
NavigationTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait]> {
let description = [{
Rewrite a supported tensor operation that is not in destination-passing style
into a form that is in destination-passing style.
@ -963,6 +966,13 @@ def RewriteInDestinationPassingStyleOp : Op<
$target attr-dict
`:` functional-type($target, results)
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
//===----------------------------------------------------------------------===//

View File

@ -2000,24 +2000,21 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::RewriteInDestinationPassingStyleOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
transform::RewriteInDestinationPassingStyleOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
SmallVector<Operation *> res;
ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
for (Operation *target : targetOps) {
IRRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
[&rewriter](auto op) {
return rewriteInDestinationPassingStyle(rewriter, op);
});
if (failed(maybeResult))
return emitDefaultSilenceableFailure(target);
res.push_back(*maybeResult);
}
results.set(getResult().cast<OpResult>(), res);
IRRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
[&rewriter](auto op) {
return rewriteInDestinationPassingStyle(rewriter, op);
});
if (failed(maybeResult))
return emitDefaultSilenceableFailure(target);
results.push_back(*maybeResult);
return DiagnosedSilenceableFailure::success();
}