[mlir] Make LocalAliasAnalysis extesible

This is an alternative to https://reviews.llvm.org/D138761 . Instead of adding ad-hoc attributes to existing `LocalAliasAnalysis`, expose `aliasImpl` method so user can override it.

Differential Revision: https://reviews.llvm.org/D140348
This commit is contained in:
Ivan Butygin 2022-12-19 22:26:07 +01:00
parent 11e0500598
commit d42cb02448
4 changed files with 85 additions and 2 deletions

View File

@ -28,6 +28,10 @@ public:
/// Return the modify-reference behavior of `op` on `location`.
ModRefResult getModRef(Operation *op, Value location);
protected:
/// Given the two values, return their aliasing behavior.
virtual AliasResult aliasImpl(Value lhs, Value rhs);
};
} // namespace mlir

View File

@ -246,7 +246,7 @@ getAllocEffectFor(Value value,
}
/// Given the two values, return their aliasing behavior.
static AliasResult aliasImpl(Value lhs, Value rhs) {
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
if (lhs == rhs)
return AliasResult::MustAlias;
Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr;

View File

@ -0,0 +1,15 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-alias-analysis-extending))' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s
// CHECK-LABEL: Testing : "restrict"
// CHECK-DAG: func.region0#0 <-> func.region0#1: NoAlias
// CHECK-DAG: view1#0 <-> view2#0: NoAlias
// CHECK-DAG: view1#0 <-> func.region0#0: MustAlias
// CHECK-DAG: view1#0 <-> func.region0#1: NoAlias
// CHECK-DAG: view2#0 <-> func.region0#0: NoAlias
// CHECK-DAG: view2#0 <-> func.region0#1: MustAlias
func.func @restrict(%arg: memref<?xf32>, %arg1: memref<?xf32> {local_alias_analysis.restrict}) attributes {test.ptr = "func"} {
%0 = memref.subview %arg[0][2][1] {test.ptr = "view1"} : memref<?xf32> to memref<2xf32>
%1 = memref.subview %arg1[0][2][1] {test.ptr = "view2"} : memref<?xf32> to memref<2xf32>
return
}

View File

@ -13,6 +13,8 @@
#include "TestAliasAnalysis.h"
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
@ -148,6 +150,67 @@ struct TestAliasAnalysisModRefPass
};
} // namespace
//===----------------------------------------------------------------------===//
// Testing LocalAliasAnalysis extending
//===----------------------------------------------------------------------===//
/// Check if value is function argument.
static bool isFuncArg(Value val) {
auto blockArg = val.dyn_cast<BlockArgument>();
if (!blockArg)
return false;
return mlir::isa_and_nonnull<FunctionOpInterface>(
blockArg.getOwner()->getParentOp());
}
/// Check if value has "restrict" attribute. Value must be a function argument.
static bool isRestrict(Value val) {
auto blockArg = val.cast<BlockArgument>();
auto func =
mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
return !!func.getArgAttr(blockArg.getArgNumber(),
"local_alias_analysis.restrict");
}
namespace {
/// LocalAliasAnalysis extended to support "restrict" attreibute.
class LocalAliasAnalysisRestrict : public LocalAliasAnalysis {
protected:
AliasResult aliasImpl(Value lhs, Value rhs) override {
if (lhs == rhs)
return AliasResult::MustAlias;
// Assume no aliasing if both values are function arguments and any of them
// have restrict attr.
if (isFuncArg(lhs) && isFuncArg(rhs))
if (isRestrict(lhs) || isRestrict(rhs))
return AliasResult::NoAlias;
return LocalAliasAnalysis::aliasImpl(lhs, rhs);
}
};
/// This pass tests adding additional analysis impls to the AliasAnalysis.
struct TestAliasAnalysisExtendingPass
: public test::TestAliasAnalysisBase,
PassWrapper<TestAliasAnalysisExtendingPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisExtendingPass)
StringRef getArgument() const final {
return "test-alias-analysis-extending";
}
StringRef getDescription() const final {
return "Test alias analysis extending.";
}
void runOnOperation() override {
AliasAnalysis aliasAnalysis(getOperation());
aliasAnalysis.addAnalysisImplementation(LocalAliasAnalysisRestrict());
runAliasAnalysisOnOperation(getOperation(), aliasAnalysis);
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass Registration
//===----------------------------------------------------------------------===//
@ -155,8 +218,9 @@ struct TestAliasAnalysisModRefPass
namespace mlir {
namespace test {
void registerTestAliasAnalysisPass() {
PassRegistration<TestAliasAnalysisPass>();
PassRegistration<TestAliasAnalysisExtendingPass>();
PassRegistration<TestAliasAnalysisModRefPass>();
PassRegistration<TestAliasAnalysisPass>();
}
} // namespace test
} // namespace mlir