[mlir][spirv] Fix scf.yield pattern conversion
Only rewrite `scf.yield` when the parent op is supported by scf-to-spirv. Fixes: #61380, #61107, #61148 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D146080
This commit is contained in:
parent
ab5eae0164
commit
dfee4c7fb0
|
@ -291,18 +291,28 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
ValueRange operands = adaptor.getOperands();
|
ValueRange operands = adaptor.getOperands();
|
||||||
|
|
||||||
// If the region is return values, store each value into the associated
|
Operation *parent = terminatorOp->getParentOp();
|
||||||
|
|
||||||
|
// TODO: Implement conversion for the remaining `scf` ops.
|
||||||
|
if (parent->getDialect()->getNamespace() ==
|
||||||
|
scf::SCFDialect::getDialectNamespace() &&
|
||||||
|
!isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
terminatorOp,
|
||||||
|
llvm::formatv("conversion not supported for parent op: '{0}'",
|
||||||
|
parent->getName()));
|
||||||
|
|
||||||
|
// If the region return values, store each value into the associated
|
||||||
// VariableOp created during lowering of the parent region.
|
// VariableOp created during lowering of the parent region.
|
||||||
if (!operands.empty()) {
|
if (!operands.empty()) {
|
||||||
auto &allocas =
|
auto &allocas = scfToSPIRVContext->outputVars[parent];
|
||||||
scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
|
|
||||||
if (allocas.size() != operands.size())
|
if (allocas.size() != operands.size())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loc = terminatorOp.getLoc();
|
auto loc = terminatorOp.getLoc();
|
||||||
for (unsigned i = 0, e = operands.size(); i < e; i++)
|
for (unsigned i = 0, e = operands.size(); i < e; i++)
|
||||||
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
|
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
|
||||||
if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
|
if (isa<spirv::LoopOp>(parent)) {
|
||||||
// For loops we also need to update the branch jumping back to the
|
// For loops we also need to update the branch jumping back to the
|
||||||
// header.
|
// header.
|
||||||
auto br = cast<spirv::BranchOp>(
|
auto br = cast<spirv::BranchOp>(
|
||||||
|
|
21
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
Normal file
21
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
// `scf.parallel` conversion is not supported yet.
|
||||||
|
// Make sure that we do not accidentally invalidate this functio by removing
|
||||||
|
// `scf.yield`.
|
||||||
|
// CHECK-LABEL: func.func @func
|
||||||
|
// CHECK: scf.parallel
|
||||||
|
// CHECK-NEXT: spirv.Constant
|
||||||
|
// CHECK-NEXT: memref.store
|
||||||
|
// CHECK-NEXT: scf.yield
|
||||||
|
// CHECK: spirv.Return
|
||||||
|
func.func @func(%arg0: i64) {
|
||||||
|
%0 = arith.index_cast %arg0 : i64 to index
|
||||||
|
%alloc = memref.alloc() : memref<16xf32>
|
||||||
|
scf.parallel (%arg1) = (%0) to (%0) step (%0) {
|
||||||
|
%cst = arith.constant 1.000000e+00 : f32
|
||||||
|
memref.store %cst, %alloc[%arg1] : memref<16xf32>
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user