[mlir] Add a generic data-flow analysis framework

This patch introduces a generic data-flow analysis framework to MLIR. The framework implements a fixed-point iteration algorithm and a dependency graph between lattice states and analysis. Lattice states and points are fully extensible to support highly-customizable analyses.

Reviewed By: phisiart, rriddle

Differential Revision: https://reviews.llvm.org/D126751
This commit is contained in:
Mogball 2022-06-13 21:54:52 +00:00
parent 3b54aa9eab
commit 9dea117283
9 changed files with 903 additions and 5 deletions

View File

@ -0,0 +1,454 @@
//===- DataFlowFramework.h - A generic framework for data-flow analysis ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a generic framework for writing data-flow analysis in MLIR.
// The framework consists of a solver, which runs the fixed-point iteration and
// manages analysis dependencies, and a data-flow analysis class used to
// implement specific analyses.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
#define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/TypeName.h"
#include <queue>
namespace mlir {
/// Forward declare the analysis state class.
class AnalysisState;
//===----------------------------------------------------------------------===//
// GenericProgramPoint
//===----------------------------------------------------------------------===//
/// Abstract class for generic program points. In classical data-flow analysis,
/// programs points represent positions in a program to which lattice elements
/// are attached. In sparse data-flow analysis, these can be SSA values, and in
/// dense data-flow analysis, these are the program points before and after
/// every operation.
///
/// In the general MLIR data-flow analysis framework, program points are an
/// extensible concept. Program points are uniquely identifiable objects to
/// which analysis states can be attached. The semantics of program points are
/// defined by the analyses that specify their transfer functions.
///
/// Program points are implemented using MLIR's storage uniquer framework and
/// type ID system to provide RTTI.
class GenericProgramPoint : public StorageUniquer::BaseStorage {
public:
virtual ~GenericProgramPoint();
/// Get the abstract program point's type identifier.
TypeID getTypeID() const { return typeID; }
/// Get a derived source location for the program point.
virtual Location getLoc() const = 0;
/// Print the program point.
virtual void print(raw_ostream &os) const = 0;
protected:
/// Create an abstract program point with type identifier.
explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
private:
/// The type identifier of the program point.
TypeID typeID;
};
//===----------------------------------------------------------------------===//
// GenericProgramPointBase
//===----------------------------------------------------------------------===//
/// Base class for generic program points based on a concrete program point
/// type and a content key. This class defines the common methods required for
/// operability with the storage uniquer framework.
///
/// The provided key type uniquely identifies the concrete program point
/// instance and are the data members of the class.
template <typename ConcreteT, typename Value>
class GenericProgramPointBase : public GenericProgramPoint {
public:
/// The concrete key type used by the storage uniquer. This class is uniqued
/// by its contents.
using KeyTy = Value;
/// Alias for the base class.
using Base = GenericProgramPointBase<ConcreteT, Value>;
/// Construct an instance of the program point using the provided value and
/// the type ID of the concrete type.
template <typename ValueT>
explicit GenericProgramPointBase(ValueT &&value)
: GenericProgramPoint(TypeID::get<ConcreteT>()),
value(std::forward<ValueT>(value)) {}
/// Get a uniqued instance of this program point class with the given
/// arguments.
template <typename... Args>
static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
}
/// Allocate space for a program point and construct it in-place.
template <typename ValueT>
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
ValueT &&value) {
return new (alloc.allocate<ConcreteT>())
ConcreteT(std::forward<ValueT>(value));
}
/// Two program points are equal if their values are equal.
bool operator==(const Value &value) const { return this->value == value; }
/// Provide LLVM-style RTTI using type IDs.
static bool classof(const GenericProgramPoint *point) {
return point->getTypeID() == TypeID::get<ConcreteT>();
}
/// Get the contents of the program point.
const Value &getValue() const { return value; }
private:
/// The program point value.
Value value;
};
//===----------------------------------------------------------------------===//
// ProgramPoint
//===----------------------------------------------------------------------===//
/// Fundamental IR components are supported as first-class program points.
struct ProgramPoint : public PointerUnion<GenericProgramPoint *, Operation *,
Value, Block *, Region *> {
using ParentTy = PointerUnion<GenericProgramPoint *, Operation *, Value,
Block *, Region *>;
/// Inherit constructors.
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
/// Print the program point.
void print(raw_ostream &os) const;
/// Get the source location of the program point.
Location getLoc() const;
};
/// Forward declaration of the data-flow analysis class.
class DataFlowAnalysis;
//===----------------------------------------------------------------------===//
// DataFlowSolver
//===----------------------------------------------------------------------===//
/// The general data-flow analysis solver. This class is responsible for
/// orchestrating child data-flow analyses, running the fixed-point iteration
/// algorithm, managing analysis state and program point memory, and tracking
/// dependencies beteen analyses, program points, and analysis states.
///
/// Steps to run a data-flow analysis:
///
/// 1. Load and initialize children analyses. Children analyses are instantiated
/// in the solver and initialized, building their dependency relations.
/// 2. Configure and run the analysis. The solver invokes the children analyses
/// according to their dependency relations until a fixed point is reached.
/// 3. Query analysis state results from the solver.
///
/// TODO: Optimize the internal implementation of the solver.
class DataFlowSolver {
public:
/// Load an analysis into the solver. Return the analysis instance.
template <typename AnalysisT, typename... Args>
AnalysisT *load(Args &&...args);
/// Initialize the children analyses starting from the provided top-level
/// operation and run the analysis until fixpoint.
LogicalResult initializeAndRun(Operation *top);
/// Lookup an analysis state for the given program point. Returns null if one
/// does not exist.
template <typename StateT, typename PointT>
const StateT *lookupState(PointT point) const {
auto it = analysisStates.find({point, TypeID::get<StateT>()});
if (it == analysisStates.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());
}
/// Get a uniqued program point instance. If one is not present, it is
/// created with the provided arguments.
template <typename PointT, typename... Args>
PointT *getProgramPoint(Args &&...args) {
return PointT::get(uniquer, std::forward<Args>(args)...);
}
/// A work item on the solver queue is a program point, child analysis pair.
/// Each item is processed by invoking the child analysis at the program
/// point.
using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis *>;
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
protected:
/// Get the state associated with the given program point. If it does not
/// exist, create an uninitialized state.
template <typename StateT, typename PointT>
StateT *getOrCreateState(PointT point);
/// Propagate an update to an analysis state if it changed by pushing
/// dependent work items to the back of the queue.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
/// Add a dependency to an analysis state on a child analysis and program
/// point. If the state is updated, the child analysis must be invoked on the
/// given program point again.
void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
ProgramPoint point);
private:
/// The solver's work queue. Work items can be inserted to the front of the
/// queue to be processed greedily, speeding up computations that otherwise
/// quickly degenerate to quadratic due to propagation of state updates.
std::queue<WorkItem> worklist;
/// Type-erased instances of the children analyses.
SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
/// The storage uniquer instance that owns the memory of the allocated program
/// points.
StorageUniquer uniquer;
/// A type-erased map of program points to associated analysis states for
/// first-class program points.
DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
analysisStates;
/// Allow the base child analysis class to access the internals of the solver.
friend class DataFlowAnalysis;
};
//===----------------------------------------------------------------------===//
// AnalysisState
//===----------------------------------------------------------------------===//
/// Base class for generic analysis states. Analysis states contain data-flow
/// information that are attached to program points and which evolve as the
/// analysis iterates.
///
/// This class places no restrictions on the semantics of analysis states beyond
/// these requirements.
///
/// 1. Querying the state of a program point prior to visiting that point
/// results in uninitialized state. Analyses must be aware of unintialized
/// states.
/// 2. Analysis states can reach fixpoints, where subsequent updates will never
/// trigger a change in the state.
/// 3. Analysis states that are uninitialized can be forcefully initialized to a
/// default value.
class AnalysisState {
public:
virtual ~AnalysisState();
/// Create the analysis state at the given program point.
AnalysisState(ProgramPoint point) : point(point) {}
/// Returns true if the analysis state is uninitialized.
virtual bool isUninitialized() const = 0;
/// Force an uninitialized analysis state to initialize itself with a default
/// value.
virtual ChangeResult defaultInitialize() = 0;
/// Print the contents of the analysis state.
virtual void print(raw_ostream &os) const = 0;
protected:
/// This function is called by the solver when the analysis state is updated
/// to optionally enqueue more work items. For example, if a state tracks
/// dependents through the IR (e.g. use-def chains), this function can be
/// implemented to push those dependents on the worklist.
virtual void onUpdate(DataFlowSolver *solver) const {}
/// The dependency relations originating from this analysis state. An entry
/// `state -> (analysis, point)` is created when `analysis` queries `state`
/// when updating `point`.
///
/// When this state is updated, all dependent child analysis invocations are
/// pushed to the back of the queue. Use a `SetVector` to keep the analysis
/// deterministic.
///
/// Store the dependents on the analysis state for efficiency.
SetVector<DataFlowSolver::WorkItem> dependents;
/// The program point to which the state belongs.
ProgramPoint point;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analysis state.
StringRef debugName;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Allow the framework to access the dependents.
friend class DataFlowSolver;
};
//===----------------------------------------------------------------------===//
// DataFlowAnalysis
//===----------------------------------------------------------------------===//
/// Base class for all data-flow analyses. A child analysis is expected to build
/// an initial dependency graph (and optionally provide an initial state) when
/// initialized and define transfer functions when visiting program points.
///
/// In classical data-flow analysis, the dependency graph is fixed and analyses
/// define explicit transfer functions between input states and output states.
/// In this framework, however, the dependency graph can change during the
/// analysis, and transfer functions are opaque such that the solver doesn't
/// know what states calling `visit` on an analysis will be updated. This allows
/// multiple analyses to plug in and provide values for the same state.
///
/// Generally, when an analysis queries an uninitialized state, it is expected
/// to "bail out", i.e., not provide any updates. When the value is initialized,
/// the solver will re-invoke the analysis. If the solver exhausts its worklist,
/// however, and there are still uninitialized states, the solver "nudges" the
/// analyses by default-initializing those states.
class DataFlowAnalysis {
public:
virtual ~DataFlowAnalysis();
/// Create an analysis with a reference to the parent solver.
explicit DataFlowAnalysis(DataFlowSolver &solver);
/// Initialize the analysis from the provided top-level operation by building
/// an initial dependency graph between all program points of interest. This
/// can be implemented by calling `visit` on all program points of interest
/// below the top-level operation.
///
/// An analysis can optionally provide initial values to certain analysis
/// states to influence the evolution of the analysis.
virtual LogicalResult initialize(Operation *top) = 0;
/// Visit the given program point. This function is invoked by the solver on
/// this analysis with a given program point when a dependent analysis state
/// is updated. The function is similar to a transfer function; it queries
/// certain analysis states and sets other states.
///
/// The function is expected to create dependencies on queried states and
/// propagate updates on changed states. A dependency can be created by
/// calling `addDependency` between the input state and a program point,
/// indicating that, if the state is updated, the solver should invoke `solve`
/// on the program point. The dependent point does not have to be the same as
/// the provided point. An update to a state is propagated by calling
/// `propagateIfChange` on the state. If the state has changed, then all its
/// dependents are placed on the worklist.
///
/// The dependency graph does not need to be static. Each invocation of
/// `visit` can add new dependencies, but these dependecies will not be
/// dynamically added to the worklist because the solver doesn't know what
/// will provide a value for then.
virtual LogicalResult visit(ProgramPoint point) = 0;
protected:
/// Create a dependency between the given analysis state and program point
/// on this analysis.
void addDependency(AnalysisState *state, ProgramPoint point);
/// Propagate an update to a state if it changed.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
/// Register a custom program point class.
template <typename PointT>
void registerPointKind() {
solver.uniquer.registerParametricStorageType<PointT>();
}
/// Get or create a custom program point.
template <typename PointT, typename... Args>
PointT *getProgramPoint(Args &&...args) {
return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
}
/// Get the analysis state assiocated with the program point. The returned
/// state is expected to be "write-only", and any updates need to be
/// propagated by `propagateIfChanged`.
template <typename StateT, typename PointT>
StateT *getOrCreate(PointT point) {
return solver.getOrCreateState<StateT>(point);
}
/// Get a read-only analysis state for the given point and create a dependency
/// on `dependent`. If the return state is updated elsewhere, this analysis is
/// re-invoked on the dependent.
template <typename StateT, typename PointT>
const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
StateT *state = getOrCreate<StateT>(point);
addDependency(state, dependent);
return state;
}
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analyis.
StringRef debugName;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
private:
/// The parent data-flow solver.
DataFlowSolver &solver;
/// Allow the data-flow solver to access the internals of this class.
friend class DataFlowSolver;
};
template <typename AnalysisT, typename... Args>
AnalysisT *DataFlowSolver::load(Args &&...args) {
childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
childAnalyses.back().get()->debugName = llvm::getTypeName<AnalysisT>();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return static_cast<AnalysisT *>(childAnalyses.back().get());
}
template <typename StateT, typename PointT>
StateT *DataFlowSolver::getOrCreateState(PointT point) {
std::unique_ptr<AnalysisState> &state =
analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
if (!state) {
state = std::unique_ptr<StateT>(new StateT(point));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state->debugName = llvm::getTypeName<StateT>();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
return static_cast<StateT *>(state.get());
}
inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
state.print(os);
return os;
}
inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
point.print(os);
return os;
}
} // end namespace mlir
namespace llvm {
/// Allow hashing of program points.
template <>
struct DenseMapInfo<mlir::ProgramPoint>
: public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
} // end namespace llvm
#endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H

View File

@ -16,6 +16,7 @@ add_mlir_library(MLIRAnalysis
BufferViewFlowAnalysis.cpp
CallGraph.cpp
DataFlowAnalysis.cpp
DataFlowFramework.cpp
DataLayoutAnalysis.cpp
IntRangeAnalysis.cpp
Liveness.cpp

View File

@ -0,0 +1,161 @@
//===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlowFramework.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "dataflow"
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
#define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
#else
#define DATAFLOW_DEBUG(X)
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
using namespace mlir;
//===----------------------------------------------------------------------===//
// GenericProgramPoint
//===----------------------------------------------------------------------===//
GenericProgramPoint::~GenericProgramPoint() = default;
//===----------------------------------------------------------------------===//
// AnalysisState
//===----------------------------------------------------------------------===//
AnalysisState::~AnalysisState() = default;
//===----------------------------------------------------------------------===//
// ProgramPoint
//===----------------------------------------------------------------------===//
void ProgramPoint::print(raw_ostream &os) const {
if (isNull()) {
os << "<NULL POINT>";
return;
}
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
return programPoint->print(os);
if (auto *op = dyn_cast<Operation *>())
return op->print(os);
if (auto value = dyn_cast<Value>())
return value.print(os);
if (auto *block = dyn_cast<Block *>())
return block->print(os);
auto *region = get<Region *>();
os << "{\n";
for (Block &block : *region) {
block.print(os);
os << "\n";
}
os << "}";
}
Location ProgramPoint::getLoc() const {
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
return programPoint->getLoc();
if (auto *op = dyn_cast<Operation *>())
return op->getLoc();
if (auto value = dyn_cast<Value>())
return value.getLoc();
if (auto *block = dyn_cast<Block *>())
return block->getParent()->getLoc();
return get<Region *>()->getLoc();
}
//===----------------------------------------------------------------------===//
// DataFlowSolver
//===----------------------------------------------------------------------===//
LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
DATAFLOW_DEBUG(llvm::dbgs()
<< "Priming analysis: " << analysis.debugName << "\n");
if (failed(analysis.initialize(top)))
return failure();
}
// Run the analysis until fixpoint.
ProgramPoint point;
DataFlowAnalysis *analysis;
do {
// Exhaust the worklist.
while (!worklist.empty()) {
std::tie(point, analysis) = worklist.front();
worklist.pop();
DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
<< "' on: " << point << "\n");
if (failed(analysis->visit(point)))
return failure();
}
// "Nudge" the state of the analysis by forcefully initializing states that
// are still uninitialized. All uninitialized states in the graph can be
// initialized in any order because the analysis reached fixpoint, meaning
// that there are no work items that would have further nudged the analysis.
for (AnalysisState &state :
llvm::make_pointee_range(llvm::make_second_range(analysisStates))) {
if (!state.isUninitialized())
continue;
DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName
<< " of " << state.point << "\n");
propagateIfChanged(&state, state.defaultInitialize());
}
// Iterate until all states are in some initialized state and the worklist
// is exhausted.
} while (!worklist.empty());
return success();
}
void DataFlowSolver::propagateIfChanged(AnalysisState *state,
ChangeResult changed) {
if (changed == ChangeResult::Change) {
DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
<< " of " << state->point << "\n"
<< "Value: " << *state << "\n");
for (const WorkItem &item : state->dependents)
enqueue(item);
state->onUpdate(this);
}
}
void DataFlowSolver::addDependency(AnalysisState *state,
DataFlowAnalysis *analysis,
ProgramPoint point) {
auto inserted = state->dependents.insert({point, analysis});
(void)inserted;
DATAFLOW_DEBUG({
if (inserted) {
llvm::dbgs() << "Creating dependency between " << state->debugName
<< " of " << state->point << "\nand " << analysis->debugName
<< " on " << point << "\n";
}
});
}
//===----------------------------------------------------------------------===//
// DataFlowAnalysis
//===----------------------------------------------------------------------===//
DataFlowAnalysis::~DataFlowAnalysis() = default;
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
solver.addDependency(state, this, point);
}
void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
ChangeResult changed) {
solver.propagateIfChanged(state, changed);
}

View File

@ -0,0 +1,95 @@
// RUN: mlir-opt -split-input-file -pass-pipeline='func.func(test-foo-analysis)' %s 2>&1 | FileCheck %s
// CHECK-LABEL: function: @test_default_init
func.func @test_default_init() -> () {
// CHECK: a -> 0
"test.foo"() {tag = "a"} : () -> ()
return
}
// -----
// CHECK-LABEL: function: @test_one_join
func.func @test_one_join() -> () {
// CHECK: a -> 0
"test.foo"() {tag = "a"} : () -> ()
// CHECK: b -> 1
"test.foo"() {tag = "b", foo = 1 : ui64} : () -> ()
return
}
// -----
// CHECK-LABEL: function: @test_two_join
func.func @test_two_join() -> () {
// CHECK: a -> 0
"test.foo"() {tag = "a"} : () -> ()
// CHECK: b -> 1
"test.foo"() {tag = "b", foo = 1 : ui64} : () -> ()
// CHECK: c -> 0
"test.foo"() {tag = "c", foo = 1 : ui64} : () -> ()
return
}
// -----
// CHECK-LABEL: function: @test_fork
func.func @test_fork() -> () {
// CHECK: init -> 1
"test.branch"() [^bb0, ^bb1] {tag = "init", foo = 1 : ui64} : () -> ()
^bb0:
// CHECK: a -> 3
"test.branch"() [^bb2] {tag = "a", foo = 2 : ui64} : () -> ()
^bb1:
// CHECK: b -> 5
"test.branch"() [^bb2] {tag = "b", foo = 4 : ui64} : () -> ()
^bb2:
// CHECK: end -> 6
"test.foo"() {tag = "end"} : () -> ()
return
}
// -----
// CHECK-LABEL: function: @test_simple_loop
func.func @test_simple_loop() -> () {
// CHECK: init -> 1
"test.branch"() [^bb0] {tag = "init", foo = 1 : ui64} : () -> ()
^bb0:
// CHECK: a -> 1
"test.foo"() {tag = "a", foo = 3 : ui64} : () -> ()
"test.branch"() [^bb0, ^bb1] : () -> ()
^bb1:
// CHECK: end -> 3
"test.foo"() {tag = "end"} : () -> ()
return
}
// -----
// CHECK-LABEL: function: @test_double_loop
func.func @test_double_loop() -> () {
// CHECK: init -> 2
"test.branch"() [^bb0] {tag = "init", foo = 2 : ui64} : () -> ()
^bb0:
// CHECK: a -> 1
"test.foo"() {tag = "a", foo = 3 : ui64} : () -> ()
"test.branch"() [^bb0, ^bb1] : () -> ()
^bb1:
// CHECK: b -> 4
"test.foo"() {tag = "b", foo = 5 : ui64} : () -> ()
"test.branch"() [^bb0, ^bb2] : () -> ()
^bb2:
// CHECK: end -> 4
"test.foo"() {tag = "end"} : () -> ()
return
}

View File

@ -3,6 +3,7 @@ add_mlir_library(MLIRTestAnalysis
TestAliasAnalysis.cpp
TestCallGraph.cpp
TestDataFlow.cpp
TestDataFlowFramework.cpp
TestLiveness.cpp
TestMatchReduction.cpp
TestMemRefBoundCheck.cpp

View File

@ -0,0 +1,188 @@
//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// This analysis state represents an integer that is XOR'd with other states.
class FooState : public AnalysisState {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
using AnalysisState::AnalysisState;
/// Default-initialize the state to zero.
ChangeResult defaultInitialize() override { return join(0); }
/// Returns true if the state is uninitialized.
bool isUninitialized() const override { return !state; }
/// Print the integer value or "none" if uninitialized.
void print(raw_ostream &os) const override {
if (state)
os << *state;
else
os << "none";
}
/// Join the state with another. If either is unintialized, take the
/// initialized value. Otherwise, XOR the integer values.
ChangeResult join(const FooState &rhs) {
if (rhs.isUninitialized())
return ChangeResult::NoChange;
return join(*rhs.state);
}
ChangeResult join(uint64_t value) {
if (isUninitialized()) {
state = value;
return ChangeResult::Change;
}
uint64_t before = *state;
state = before ^ value;
return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
}
/// Set the value of the state directly.
ChangeResult set(const FooState &rhs) {
if (state == rhs.state)
return ChangeResult::NoChange;
state = rhs.state;
return ChangeResult::Change;
}
/// Returns the integer value of the state.
uint64_t getValue() const { return *state; }
private:
/// An optional integer value.
Optional<uint64_t> state;
};
/// This analysis computes `FooState` across operations and control-flow edges.
/// If an op specifies a `foo` integer attribute, the contained value is XOR'd
/// with the value before the operation.
class FooAnalysis : public DataFlowAnalysis {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
using DataFlowAnalysis::DataFlowAnalysis;
LogicalResult initialize(Operation *top) override;
LogicalResult visit(ProgramPoint point) override;
private:
void visitBlock(Block *block);
void visitOperation(Operation *op);
};
struct TestFooAnalysisPass
: public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
StringRef getArgument() const override { return "test-foo-analysis"; }
void runOnOperation() override;
};
} // namespace
LogicalResult FooAnalysis::initialize(Operation *top) {
if (top->getNumRegions() != 1)
return top->emitError("expected a single region top-level op");
// Initialize the top-level state.
getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
// Visit all nested blocks and operations.
for (Block &block : top->getRegion(0)) {
visitBlock(&block);
for (Operation &op : block) {
if (op.getNumRegions())
return op.emitError("unexpected op with regions");
visitOperation(&op);
}
}
return success();
}
LogicalResult FooAnalysis::visit(ProgramPoint point) {
if (auto *op = point.dyn_cast<Operation *>()) {
visitOperation(op);
return success();
}
if (auto *block = point.dyn_cast<Block *>()) {
visitBlock(block);
return success();
}
return emitError(point.getLoc(), "unknown point kind");
}
void FooAnalysis::visitBlock(Block *block) {
if (block->isEntryBlock()) {
// This is the initial state. Let the framework default-initialize it.
return;
}
FooState *state = getOrCreate<FooState>(block);
ChangeResult result = ChangeResult::NoChange;
for (Block *pred : block->getPredecessors()) {
// Join the state at the terminators of all predecessors.
const FooState *predState =
getOrCreateFor<FooState>(block, pred->getTerminator());
result |= state->join(*predState);
}
propagateIfChanged(state, result);
}
void FooAnalysis::visitOperation(Operation *op) {
FooState *state = getOrCreate<FooState>(op);
ChangeResult result = ChangeResult::NoChange;
// Copy the state across the operation.
const FooState *prevState;
if (Operation *prev = op->getPrevNode())
prevState = getOrCreateFor<FooState>(op, prev);
else
prevState = getOrCreateFor<FooState>(op, op->getBlock());
result |= state->set(*prevState);
// Modify the state with the attribute, if specified.
if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
uint64_t value = attr.getUInt();
result |= state->join(value);
}
propagateIfChanged(state, result);
}
void TestFooAnalysisPass::runOnOperation() {
func::FuncOp func = getOperation();
DataFlowSolver solver;
solver.load<FooAnalysis>();
if (failed(solver.initializeAndRun(func)))
return signalPassFailure();
raw_ostream &os = llvm::errs();
os << "function: @" << func.getSymName() << "\n";
func.walk([&](Operation *op) {
auto tag = op->getAttrOfType<StringAttr>("tag");
if (!tag)
return;
const FooState *state = solver.lookupState<FooState>(op);
assert(state && !state->isUninitialized());
os << tag.getValue() << " -> " << state->getValue() << "\n";
});
}
namespace mlir {
namespace test {
void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
} // namespace test
} // namespace mlir

View File

@ -77,6 +77,7 @@ void registerTestDiagnosticsPass();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIntRangeInference();
@ -175,6 +176,7 @@ void registerTestPasses() {
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIntRangeInference();

View File

@ -5787,17 +5787,12 @@ cc_library(
"lib/Analysis/*/*.cpp",
"lib/Analysis/*/*.h",
],
exclude = [
"lib/Analysis/Vector*.cpp",
"lib/Analysis/Vector*.h",
],
),
hdrs = glob(
[
"include/mlir/Analysis/*.h",
"include/mlir/Analysis/*/*.h",
],
exclude = ["include/mlir/Analysis/Vector*.h"],
),
includes = ["include"],
deps = [

View File

@ -26,6 +26,7 @@ cc_library(
"//mlir:AffineAnalysis",
"//mlir:AffineDialect",
"//mlir:Analysis",
"//mlir:FuncDialect",
"//mlir:IR",
"//mlir:MemRefDialect",
"//mlir:Pass",