[mlir][spirv] Fix scf.yield pattern conversion
Only rewrite `scf.yield` when the parent op is supported by scf-to-spirv. Fixes: #61380, #61107, #61148 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D146080
This commit is contained in:
parent
ab5eae0164
commit
dfee4c7fb0
|
@ -291,18 +291,28 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
ValueRange operands = adaptor.getOperands();
|
||||
|
||||
// If the region is return values, store each value into the associated
|
||||
Operation *parent = terminatorOp->getParentOp();
|
||||
|
||||
// TODO: Implement conversion for the remaining `scf` ops.
|
||||
if (parent->getDialect()->getNamespace() ==
|
||||
scf::SCFDialect::getDialectNamespace() &&
|
||||
!isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
|
||||
return rewriter.notifyMatchFailure(
|
||||
terminatorOp,
|
||||
llvm::formatv("conversion not supported for parent op: '{0}'",
|
||||
parent->getName()));
|
||||
|
||||
// If the region return values, store each value into the associated
|
||||
// VariableOp created during lowering of the parent region.
|
||||
if (!operands.empty()) {
|
||||
auto &allocas =
|
||||
scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
|
||||
auto &allocas = scfToSPIRVContext->outputVars[parent];
|
||||
if (allocas.size() != operands.size())
|
||||
return failure();
|
||||
|
||||
auto loc = terminatorOp.getLoc();
|
||||
for (unsigned i = 0, e = operands.size(); i < e; i++)
|
||||
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
|
||||
if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
|
||||
if (isa<spirv::LoopOp>(parent)) {
|
||||
// For loops we also need to update the branch jumping back to the
|
||||
// header.
|
||||
auto br = cast<spirv::BranchOp>(
|
||||
|
|
21
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
Normal file
21
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
Normal file
|
@ -0,0 +1,21 @@
|
|||
// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
|
||||
|
||||
// `scf.parallel` conversion is not supported yet.
|
||||
// Make sure that we do not accidentally invalidate this functio by removing
|
||||
// `scf.yield`.
|
||||
// CHECK-LABEL: func.func @func
|
||||
// CHECK: scf.parallel
|
||||
// CHECK-NEXT: spirv.Constant
|
||||
// CHECK-NEXT: memref.store
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK: spirv.Return
|
||||
func.func @func(%arg0: i64) {
|
||||
%0 = arith.index_cast %arg0 : i64 to index
|
||||
%alloc = memref.alloc() : memref<16xf32>
|
||||
scf.parallel (%arg1) = (%0) to (%0) step (%0) {
|
||||
%cst = arith.constant 1.000000e+00 : f32
|
||||
memref.store %cst, %alloc[%arg1] : memref<16xf32>
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue
Block a user