[mlir][sparse] fix crash when using pure constant index in indexing mapping (fixes #61530)

To address https://github.com/llvm/llvm-project/issues/61530

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D146563
This commit is contained in:
Peiming Liu 2023-03-21 20:47:47 +00:00
parent e8ad2a051c
commit 2b21327fee
9 changed files with 81 additions and 16 deletions

View File

@ -254,6 +254,9 @@ public:
/// }
///
/// to filter out coordinates that are not equal to the affine expression.
///
/// The maxLvlRank specifies the max level rank of all inputs/output tensors.
/// It is used to pre-allocate sufficient memory for internal storage.
//
// TODO: we want to make the filter loop more efficient in the future,
// e.g., by avoiding scanning the full list of stored coordinates (keeping
@ -264,7 +267,7 @@ public:
// gave the number of input tensors, instead of the current number of
// input+output tensors.
Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
unsigned numFilterLoops);
unsigned numFilterLoops, unsigned maxLvlRank);
/// Constructs a new tensor expression, and returns its identifier.
/// The type of the `e0` argument varies according to the value of the

View File

@ -51,12 +51,12 @@ static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
unsigned numTensors, unsigned numLoops,
unsigned numFilterLoops)
unsigned numFilterLoops, unsigned maxRank)
: linalgOp(linop), sparseOptions(opts),
latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(),
topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
expFilled(), expAdded(), expCount(), redVal(), redExp(kInvalidId),
redCustom(kInvalidId), redValidLexInsert() {}
latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
redExp(kInvalidId), redCustom(kInvalidId), redValidLexInsert() {}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.

View File

@ -38,7 +38,8 @@ public:
/// passed around during sparsification for bookkeeping
/// together with some consistency asserts.
CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
unsigned numTensors, unsigned numLoops, unsigned numFilterLoops);
unsigned numTensors, unsigned numLoops, unsigned numFilterLoops,
unsigned maxRank);
//
// General methods.

View File

@ -288,12 +288,18 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
coordinatesBuffers[tid].assign(lvlRank, Value());
sliceOffsets[tid].assign(lvlRank, Value());
sliceStrides[tid].assign(lvlRank, Value());
dependentLvlMap[tid].assign(lvlRank,
std::vector<std::pair<TensorId, Level>>());
if (dimGetter)
for (Level l = 0; l < lvlRank; l++)
if (dimGetter) {
auto reassoc = collapseReassoc[tid];
Level dstRank = reassoc ? reassoc.size() : lvlRank;
for (Level l = 0; l < dstRank; l++) {
dependentLvlMap[tid][l] = dimGetter(tid, l);
// TODO: View-base collapse and dependent index reduction are not
// compatible right now.
assert(!reassoc || dependentLvlMap[tid][l].empty());
}
}
}
// Construct the inverse of the `topSort` from the sparsifier.

View File

@ -1811,10 +1811,24 @@ public:
// possible, we can even intermix slice-based and filter-loop based codegen.
bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0;
// If we have indexing map like (d0) -> (0, d0), there might be more
// levels then loops because of the constant index, that means we can not
// use numLoops as the upper bound for ranks of all tensors.
// TODO: Constant indices are currently not support on sparse tensor, but
// are allowed in non-annotated dense tensor. Support it, it would be
// required for sparse tensor slice rank reducing too.
Level maxLvlRank = 0;
for (auto operand : op.getOperands()) {
if (auto rtp = operand.getType().dyn_cast<RankedTensorType>()) {
maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
}
}
// If we uses slice based algorithm for affine index, we do not need filter
// loop.
CodegenEnv env(op, options, numTensors, numLoops,
/*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops);
/*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
maxLvlRank);
// Detects sparse annotations and translates the per-level sparsity
// information for all tensors to loop indices in the kernel.

View File

@ -210,7 +210,7 @@ LatPoint::LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i,
}
Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
unsigned numFilterLoops)
unsigned numFilterLoops, unsigned maxLvlRank)
: outTensor(numInputOutputTensors - 1),
syntheticTensor(numInputOutputTensors),
numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
@ -220,11 +220,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
loopToLvl(numTensors,
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
lvlToLoop(numTensors,
std::vector<std::optional<LoopId>>(numLoops, std::nullopt)),
std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
loopToDependencies(numLoops, std::vector<std::optional<Level>>(
numTensors, std::nullopt)),
levelToDependentIdx(numTensors, std::vector<std::vector<LoopId>>(
numLoops, std::vector<LoopId>())),
maxLvlRank, std::vector<LoopId>())),
loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,41 @@
// Reported by https://github.com/llvm/llvm-project/issues/61530
// RUN: mlir-opt %s -sparsification | FileCheck %s
#map1 = affine_map<(d0) -> (0, d0)>
#map2 = affine_map<(d0) -> (d0)>
#SpVec = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
// CHECK-LABEL: func.func @main(
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<1x77xi1>,
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<1x77xi1>) -> tensor<77xi1, #{{.*}}> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 77 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<77xi1, #{{.*}}>
// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<1x77xi1>
// CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<1x77xi1>
// CHECK: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_5]]) -> (tensor<77xi1, #{{.*}}>) {
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1
// CHECK: %[[VAL_14:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
// CHECK: scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}>
// CHECK: return %[[VAL_15]] : tensor<77xi1, #{{.*}}>
// CHECK: }
func.func @main(%arg0: tensor<1x77xi1>, %arg1: tensor<1x77xi1>) -> tensor<77xi1, #SpVec> {
%0 = bufferization.alloc_tensor() : tensor<77xi1, #SpVec>
%1 = linalg.generic {
indexing_maps = [#map1, #map1, #map2],
iterator_types = ["parallel"]}
ins(%arg0, %arg1 : tensor<1x77xi1>, tensor<1x77xi1>)
outs(%0 : tensor<77xi1, #SpVec>) {
^bb0(%in: i1, %in_0: i1, %out: i1):
%2 = arith.addi %in, %in_0 : i1
linalg.yield %2 : i1
} -> tensor<77xi1, #SpVec>
return %1 : tensor<77xi1, #SpVec>
}

View File

@ -496,4 +496,3 @@ func.func @mul_const_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>,
} -> tensor<32x16xf64>
return %0 : tensor<32x16xf64>
}

View File

@ -128,7 +128,8 @@ class MergerTestBase : public ::testing::Test {
protected:
MergerTestBase(unsigned numTensors, unsigned numLoops)
: numTensors(numTensors), numLoops(numLoops),
merger(numTensors, numLoops, /*numFilterLoops=*/0) {}
merger(numTensors, numLoops, /*numFilterLoops=*/0,
/*maxRank=*/numLoops) {}
///
/// Expression construction helpers.