[mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand
This commit is contained in:
parent
d5b2c8e56d
commit
9437bf418a
|
@ -83,8 +83,10 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
|
||||
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
|
||||
TransformOpInterface, TransformEachOpTrait]> {
|
||||
[FunctionalStyleTransformOpTrait,
|
||||
MemoryEffectsOpInterface,
|
||||
TransformOpInterface,
|
||||
TransformEachOpTrait]> {
|
||||
let description = [{
|
||||
Decomposes named complex operations, such as higher-dimensional
|
||||
(depthwise) convolutions, into combinations of lower-dimensional equivalents
|
||||
|
@ -932,9 +934,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
|
|||
|
||||
def RewriteInDestinationPassingStyleOp : Op<
|
||||
Transform_Dialect, "structured.rewrite_in_destination_passing_style",
|
||||
[MemoryEffectsOpInterface,
|
||||
NavigationTransformOpTrait,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
[FunctionalStyleTransformOpTrait,
|
||||
MemoryEffectsOpInterface,
|
||||
TransformOpInterface,
|
||||
TransformEachOpTrait]> {
|
||||
let description = [{
|
||||
Rewrite a supported tensor operation that is not in destination-passing style
|
||||
into a form that is in destination-passing style.
|
||||
|
@ -963,6 +966,13 @@ def RewriteInDestinationPassingStyleOp : Op<
|
|||
$target attr-dict
|
||||
`:` functional-type($target, results)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::Operation *target,
|
||||
::mlir::transform::ApplyToEachResultList &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2000,24 +2000,21 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::RewriteInDestinationPassingStyleOp::apply(
|
||||
transform::TransformResults &results, transform::TransformState &state) {
|
||||
transform::RewriteInDestinationPassingStyleOp::applyToOne(
|
||||
Operation *target, transform::ApplyToEachResultList &results,
|
||||
transform::TransformState &state) {
|
||||
SmallVector<Operation *> res;
|
||||
ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
|
||||
for (Operation *target : targetOps) {
|
||||
IRRewriter rewriter(target->getContext());
|
||||
rewriter.setInsertionPoint(target);
|
||||
FailureOr<Operation *> maybeResult =
|
||||
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
|
||||
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
|
||||
[&rewriter](auto op) {
|
||||
return rewriteInDestinationPassingStyle(rewriter, op);
|
||||
});
|
||||
if (failed(maybeResult))
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
res.push_back(*maybeResult);
|
||||
}
|
||||
results.set(getResult().cast<OpResult>(), res);
|
||||
IRRewriter rewriter(target->getContext());
|
||||
rewriter.setInsertionPoint(target);
|
||||
FailureOr<Operation *> maybeResult =
|
||||
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
|
||||
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
|
||||
[&rewriter](auto op) {
|
||||
return rewriteInDestinationPassingStyle(rewriter, op);
|
||||
});
|
||||
if (failed(maybeResult))
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
results.push_back(*maybeResult);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user