diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 991c920c1739..8b1e91ae8df5 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -254,6 +254,9 @@ public: /// } /// /// to filter out coordinates that are not equal to the affine expression. + /// + /// The maxLvlRank specifies the max level rank of all inputs/output tensors. + /// It is used to pre-allocate sufficient memory for internal storage. // // TODO: we want to make the filter loop more efficient in the future, // e.g., by avoiding scanning the full list of stored coordinates (keeping @@ -264,7 +267,7 @@ public: // gave the number of input tensors, instead of the current number of // input+output tensors. Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, - unsigned numFilterLoops); + unsigned numFilterLoops, unsigned maxLvlRank); /// Constructs a new tensor expression, and returns its identifier. /// The type of the `e0` argument varies according to the value of the diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp index f326d5b950a3..974c86d1fab5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -51,12 +51,12 @@ static void sortArrayBasedOnOrder(std::vector &target, CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, unsigned numTensors, unsigned numLoops, - unsigned numFilterLoops) + unsigned numFilterLoops, unsigned maxRank) : linalgOp(linop), sparseOptions(opts), - latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(), - topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), - expFilled(), expAdded(), expCount(), redVal(), redExp(kInvalidId), - redCustom(kInvalidId), redValidLexInsert() {} + latticeMerger(numTensors, numLoops, numFilterLoops, maxRank), + loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u), + insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(), + redExp(kInvalidId), redCustom(kInvalidId), redValidLexInsert() {} LogicalResult CodegenEnv::initTensorExp() { // Builds the tensor expression for the Linalg operation in SSA form. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h index 8c6a7bd6433d..0041ad0a272c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -38,7 +38,8 @@ public: /// passed around during sparsification for bookkeeping /// together with some consistency asserts. CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, - unsigned numTensors, unsigned numLoops, unsigned numFilterLoops); + unsigned numTensors, unsigned numLoops, unsigned numFilterLoops, + unsigned maxRank); // // General methods. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index 459a1b38e03d..cae92c34e258 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -288,12 +288,18 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, coordinatesBuffers[tid].assign(lvlRank, Value()); sliceOffsets[tid].assign(lvlRank, Value()); sliceStrides[tid].assign(lvlRank, Value()); - dependentLvlMap[tid].assign(lvlRank, std::vector>()); - if (dimGetter) - for (Level l = 0; l < lvlRank; l++) + if (dimGetter) { + auto reassoc = collapseReassoc[tid]; + Level dstRank = reassoc ? reassoc.size() : lvlRank; + for (Level l = 0; l < dstRank; l++) { dependentLvlMap[tid][l] = dimGetter(tid, l); + // TODO: View-base collapse and dependent index reduction are not + // compatible right now. + assert(!reassoc || dependentLvlMap[tid][l].empty()); + } + } } // Construct the inverse of the `topSort` from the sparsifier. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 63228531fcf0..f760244d59d8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1811,10 +1811,24 @@ public: // possible, we can even intermix slice-based and filter-loop based codegen. bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0; + // If we have indexing map like (d0) -> (0, d0), there might be more + // levels then loops because of the constant index, that means we can not + // use numLoops as the upper bound for ranks of all tensors. + // TODO: Constant indices are currently not support on sparse tensor, but + // are allowed in non-annotated dense tensor. Support it, it would be + // required for sparse tensor slice rank reducing too. + Level maxLvlRank = 0; + for (auto operand : op.getOperands()) { + if (auto rtp = operand.getType().dyn_cast()) { + maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank()); + } + } + // If we uses slice based algorithm for affine index, we do not need filter // loop. CodegenEnv env(op, options, numTensors, numLoops, - /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops); + /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops, + maxLvlRank); // Detects sparse annotations and translates the per-level sparsity // information for all tensors to loop indices in the kernel. diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 0691d2554f43..9b39fd04d25e 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -210,7 +210,7 @@ LatPoint::LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i, } Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, - unsigned numFilterLoops) + unsigned numFilterLoops, unsigned maxLvlRank) : outTensor(numInputOutputTensors - 1), syntheticTensor(numInputOutputTensors), numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops), @@ -220,11 +220,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, loopToLvl(numTensors, std::vector>(numLoops, std::nullopt)), lvlToLoop(numTensors, - std::vector>(numLoops, std::nullopt)), + std::vector>(maxLvlRank, std::nullopt)), loopToDependencies(numLoops, std::vector>( numTensors, std::nullopt)), levelToDependentIdx(numTensors, std::vector>( - numLoops, std::vector())), + maxLvlRank, std::vector())), loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir new file mode 100644 index 000000000000..cbd48b06afaa --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir @@ -0,0 +1,41 @@ +// Reported by https://github.com/llvm/llvm-project/issues/61530 + +// RUN: mlir-opt %s -sparsification | FileCheck %s + +#map1 = affine_map<(d0) -> (0, d0)> +#map2 = affine_map<(d0) -> (d0)> + +#SpVec = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<1x77xi1>, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor<1x77xi1>) -> tensor<77xi1, #{{.*}}> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 77 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<77xi1, #{{.*}}> +// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<1x77xi1> +// CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<1x77xi1> +// CHECK: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_5]]) -> (tensor<77xi1, #{{.*}}>) { +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1> +// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1> +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1 +// CHECK: %[[VAL_14:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}> +// CHECK: scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}> +// CHECK: } +// CHECK: %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}> +// CHECK: return %[[VAL_15]] : tensor<77xi1, #{{.*}}> +// CHECK: } +func.func @main(%arg0: tensor<1x77xi1>, %arg1: tensor<1x77xi1>) -> tensor<77xi1, #SpVec> { + %0 = bufferization.alloc_tensor() : tensor<77xi1, #SpVec> + %1 = linalg.generic { + indexing_maps = [#map1, #map1, #map2], + iterator_types = ["parallel"]} + ins(%arg0, %arg1 : tensor<1x77xi1>, tensor<1x77xi1>) + outs(%0 : tensor<77xi1, #SpVec>) { + ^bb0(%in: i1, %in_0: i1, %out: i1): + %2 = arith.addi %in, %in_0 : i1 + linalg.yield %2 : i1 + } -> tensor<77xi1, #SpVec> + return %1 : tensor<77xi1, #SpVec> +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir index 97293348774c..2cda2335923c 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir @@ -496,4 +496,3 @@ func.func @mul_const_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>, } -> tensor<32x16xf64> return %0 : tensor<32x16xf64> } - diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index 270b5836907e..599e8abd52f3 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -128,7 +128,8 @@ class MergerTestBase : public ::testing::Test { protected: MergerTestBase(unsigned numTensors, unsigned numLoops) : numTensors(numTensors), numLoops(numLoops), - merger(numTensors, numLoops, /*numFilterLoops=*/0) {} + merger(numTensors, numLoops, /*numFilterLoops=*/0, + /*maxRank=*/numLoops) {} /// /// Expression construction helpers.