[mlir][spirv] Fix scf.yield pattern conversion

Only rewrite `scf.yield` when the parent op is supported by
scf-to-spirv.

Fixes: #61380, #61107, #61148

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D146080
This commit is contained in:
Jakub Kuderski 2023-03-14 18:47:33 -04:00
parent ab5eae0164
commit dfee4c7fb0
2 changed files with 35 additions and 4 deletions

View File

@ -291,18 +291,28 @@ public:
ConversionPatternRewriter &rewriter) const override {
ValueRange operands = adaptor.getOperands();
// If the region is return values, store each value into the associated
Operation *parent = terminatorOp->getParentOp();
// TODO: Implement conversion for the remaining `scf` ops.
if (parent->getDialect()->getNamespace() ==
scf::SCFDialect::getDialectNamespace() &&
!isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
return rewriter.notifyMatchFailure(
terminatorOp,
llvm::formatv("conversion not supported for parent op: '{0}'",
parent->getName()));
// If the region return values, store each value into the associated
// VariableOp created during lowering of the parent region.
if (!operands.empty()) {
auto &allocas =
scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
auto &allocas = scfToSPIRVContext->outputVars[parent];
if (allocas.size() != operands.size())
return failure();
auto loc = terminatorOp.getLoc();
for (unsigned i = 0, e = operands.size(); i < e; i++)
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
if (isa<spirv::LoopOp>(parent)) {
// For loops we also need to update the branch jumping back to the
// header.
auto br = cast<spirv::BranchOp>(

View File

@ -0,0 +1,21 @@
// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
// `scf.parallel` conversion is not supported yet.
// Make sure that we do not accidentally invalidate this functio by removing
// `scf.yield`.
// CHECK-LABEL: func.func @func
// CHECK: scf.parallel
// CHECK-NEXT: spirv.Constant
// CHECK-NEXT: memref.store
// CHECK-NEXT: scf.yield
// CHECK: spirv.Return
func.func @func(%arg0: i64) {
%0 = arith.index_cast %arg0 : i64 to index
%alloc = memref.alloc() : memref<16xf32>
scf.parallel (%arg1) = (%0) to (%0) step (%0) {
%cst = arith.constant 1.000000e+00 : f32
memref.store %cst, %alloc[%arg1] : memref<16xf32>
scf.yield
}
return
}