ea64828a10
This commit refactors the expected form of native constraint and rewrite functions, and greatly reduces the necessary user complexity required when defining a native function. Namely, this commit adds in automatic processing of the necessary PDLValue glue code, and allows for users to define constraint/rewrite functions using the C++ types that they actually want to use. As an example, lets see a simple example rewrite defined today: ``` static void rewriteFn(PatternRewriter &rewriter, PDLResultList &results, ArrayRef<PDLValue> args) { ValueRange operandValues = args[0].cast<ValueRange>(); TypeRange typeValues = args[1].cast<TypeRange>(); ... // Create an operation at some point and pass it back to PDL. Operation *op = rewriter.create<SomeOp>(...); results.push_back(op); } ``` After this commit, that same rewrite could be defined as: ``` static Operation *rewriteFn(PatternRewriter &rewriter ValueRange operandValues, TypeRange typeValues) { ... // Create an operation at some point and pass it back to PDL. return rewriter.create<SomeOp>(...); } ``` Differential Revision: https://reviews.llvm.org/D122086
125 lines
5.3 KiB
C++
125 lines
5.3 KiB
C++
//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Custom constraint invoked from PDL.
|
|
static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
|
|
Operation *rootOp) {
|
|
return success(rootOp->getName().getStringRef() == "test.op");
|
|
}
|
|
static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
|
|
Operation *root,
|
|
Operation *rootCopy) {
|
|
return customSingleEntityConstraint(rewriter, rootCopy);
|
|
}
|
|
static LogicalResult customMultiEntityVariadicConstraint(
|
|
PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
|
|
if (operandValues.size() != 2 || typeValues.size() != 2)
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
// Custom creator invoked from PDL.
|
|
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
|
|
return rewriter.create(OperationState(op->getLoc(), "test.success"));
|
|
}
|
|
static auto customVariadicResultCreate(PatternRewriter &rewriter,
|
|
Operation *root) {
|
|
return std::make_pair(root->getOperands(), root->getOperands().getTypes());
|
|
}
|
|
static Type customCreateType(PatternRewriter &rewriter) {
|
|
return rewriter.getF32Type();
|
|
}
|
|
static std::string customCreateStrAttr(PatternRewriter &rewriter) {
|
|
return "test.str";
|
|
}
|
|
|
|
/// Custom rewriter invoked from PDL.
|
|
static void customRewriter(PatternRewriter &rewriter, Operation *root,
|
|
Value input) {
|
|
rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
|
|
input);
|
|
rewriter.eraseOp(root);
|
|
}
|
|
|
|
namespace {
|
|
struct TestPDLByteCodePass
|
|
: public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
|
|
|
|
StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
|
|
StringRef getDescription() const final {
|
|
return "Test PDL ByteCode functionality";
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
// Mark the pdl_interp dialect as a dependent. This is needed, because we
|
|
// create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
|
|
registry.insert<pdl_interp::PDLInterpDialect>();
|
|
}
|
|
void runOnOperation() final {
|
|
ModuleOp module = getOperation();
|
|
|
|
// The test cases are encompassed via two modules, one containing the
|
|
// patterns and one containing the operations to rewrite.
|
|
ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
|
|
StringAttr::get(module->getContext(), "patterns"));
|
|
ModuleOp irModule = module.lookupSymbol<ModuleOp>(
|
|
StringAttr::get(module->getContext(), "ir"));
|
|
if (!patternModule || !irModule)
|
|
return;
|
|
|
|
RewritePatternSet patternList(module->getContext());
|
|
|
|
// Register ahead of time to test when functions are registered without a
|
|
// pattern.
|
|
patternList.getPDLPatterns().registerConstraintFunction(
|
|
"multi_entity_constraint", customMultiEntityConstraint);
|
|
patternList.getPDLPatterns().registerConstraintFunction(
|
|
"single_entity_constraint", customSingleEntityConstraint);
|
|
|
|
// Process the pattern module.
|
|
patternModule.getOperation()->remove();
|
|
PDLPatternModule pdlPattern(patternModule);
|
|
|
|
// Note: This constraint was already registered, but we re-register here to
|
|
// ensure that duplication registration is allowed (the duplicate mapping
|
|
// will be ignored). This tests that we support separating the registration
|
|
// of library functions from the construction of patterns, and also that we
|
|
// allow multiple patterns to depend on the same library functions (without
|
|
// asserting/crashing).
|
|
pdlPattern.registerConstraintFunction("multi_entity_constraint",
|
|
customMultiEntityConstraint);
|
|
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
|
|
customMultiEntityVariadicConstraint);
|
|
pdlPattern.registerRewriteFunction("creator", customCreate);
|
|
pdlPattern.registerRewriteFunction("var_creator",
|
|
customVariadicResultCreate);
|
|
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
|
|
pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
|
|
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
|
|
patternList.add(std::move(pdlPattern));
|
|
|
|
// Invoke the pattern driver with the provided patterns.
|
|
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
|
|
std::move(patternList));
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
|
|
} // namespace test
|
|
} // namespace mlir
|