diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 6e39404bb28a..991c920c1739 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -280,6 +280,7 @@ public: /// Constructs a new iteration lattice point, and returns its identifier. LatPointId addLat(TensorId t, LoopId i, ExprId e); + LatPointId addLat(const BitVector &bits, ExprId e); /// Constructs a new (initially empty) set, and returns its identifier. LatSetId addSet(); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 4a8c3cbfbe58..0691d2554f43 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -247,6 +247,13 @@ LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { return p; } +LatPointId Merger::addLat(const BitVector &bits, ExprId e) { + assert(bits.size() == numLoops * numTensors); + const LatPointId p = latPoints.size(); + latPoints.emplace_back(bits, e); + return p; +} + LatSetId Merger::addSet() { const LatSetId s = latSets.size(); latSets.emplace_back(); @@ -322,8 +329,7 @@ LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, const LatSetId s = addSet(); for (const LatPointId p : latSets[s0]) { const ExprId e = addExp(kind, latPoints[p].exp, v, op); - latPoints.emplace_back(latPoints[p].bits, e); - latSets[s].push_back(latPoints.size() - 1); + latSets[s].push_back(addLat(latPoints[p].bits, e)); } return s; }