Skip to content

Commit b1bb239

Browse files
authored
[mlir][sparse] merger cleanup (#70371)
Implemented some TODOs and removed those unlikely to be addressed. Also cleaned up comments.
1 parent cfc922f commit b1bb239

File tree

2 files changed

+22
-68
lines changed

2 files changed

+22
-68
lines changed

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 18 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,10 @@ struct TensorExp final {
122122
///
123123
/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
124124
/// That is, its argument is a `LoopId` identifying the loop-variable
125-
/// in question, and its value will be the current iteration's value
126-
/// of that loop-variable. See the `LoopId` documentation for more details.
127-
///
128-
/// The `kSynZero` leaf kind is for representing a synthetic zero value, which
129-
/// can be introduced when sparsifying operations like `arith::cmp` to generate
130-
/// `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
131-
//
132-
// TODO: Modify this definition so that the numeric values already encode
133-
// the `ExpArity` (while extending the notion of "arity" to include not
134-
// just the number of `ExprId` children the node has, but also whether the
135-
// node has a `Value` and/or `Operation*`). Doing this will avoid needing
136-
// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
137-
// and should help clean up a few other places as well.
125+
/// in question, and its value will be the current iteration's value.
126+
/// The `kSynZero` leaf kind is for representing a synthetic zero value,
127+
/// which can be introduced when sparsifying operations like `arith::cmp`
128+
/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
138129
enum class TensorExp::Kind {
139130
// Leaf.
140131
kTensor = 0,
@@ -253,15 +244,6 @@ class Merger {
253244
///
254245
/// The maxLvlRank specifies the max level rank of all inputs/output tensors.
255246
/// It is used to pre-allocate sufficient memory for internal storage.
256-
//
257-
// TODO: we want to make the filter loop more efficient in the future,
258-
// e.g., by avoiding scanning the full list of stored coordinates (keeping
259-
// the last position in ordered list) or even apply binary search to find
260-
// the coordinate.
261-
//
262-
// TODO: would be cleaner to understand/document if the first argument
263-
// gave the number of input tensors, instead of the current number of
264-
// input+output tensors.
265247
Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
266248
unsigned numFilterLoops, unsigned maxLvlRank);
267249

@@ -383,12 +365,15 @@ class Merger {
383365

384366
/// Gets the total number of loops (native loops + filter loops).
385367
constexpr unsigned getNumLoops() const { return numLoops; }
368+
386369
/// Gets the number of native loops.
387370
constexpr unsigned getNumNativeLoops() const { return numNativeLoops; }
371+
388372
/// Gets the number of filter loops.
389373
constexpr unsigned getNumFilterLoops() const {
390374
return numLoops - numNativeLoops;
391375
}
376+
392377
/// Gets the identifier of the first filter-loop.
393378
constexpr LoopId getStartingFilterLoopId() const {
394379
return getNumNativeLoops();
@@ -473,8 +458,7 @@ class Merger {
473458
lvlTypes[t][i] = dlt;
474459
loopToLvl[t][i] = lvl;
475460
lvlToLoop[t][lvl] = i;
476-
// TODO: Maybe we should favor a constant loop bound when there are multiple
477-
// choices.
461+
// TODO: favor a constant loop bound when there are multiple choices.
478462
loopBounds[i] = std::make_pair(t, lvl);
479463
}
480464

@@ -600,43 +584,19 @@ class Merger {
600584
/// Checks whether the given expression has an associated value.
601585
bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
602586

603-
/// Sets the expression to have the associated value. Asserts that
604-
/// the new value is defined, and that the expression does not already
605-
/// have a value. If you want to overwrite a previous associated value,
606-
/// use `updateExprValue` instead.
587+
/// Sets the expression to have the associated value. Asserts that the new
588+
/// value is defined, and that the expression does not already have a value.
607589
void setExprValue(ExprId e, Value v) {
608-
assert(isValidExprId(e));
609-
assert(v && "Got an undefined value");
610-
auto &val = tensorExps[e].val;
611-
assert(!val && "Expression already has an associated value");
612-
val = v;
590+
assert(!exp(e).val && "Expression already has an associated value");
591+
assert(v && "Trying to assign an undefined value");
592+
tensorExps[e].val = v;
613593
}
614594

615-
/// Clears the value associated with the expression. Asserts that the
595+
/// Clears the value associated with the expression. Asserts that the
616596
/// expression does indeed have an associated value before clearing it.
617-
/// If you don't want to check for a previous associated value first,
618-
/// then use `updateExprValue` instead.
619597
void clearExprValue(ExprId e) {
620-
assert(isValidExprId(e));
621-
auto &val = tensorExps[e].val;
622-
assert(val && "Expression does not have an associated value to clear");
623-
val = Value();
624-
}
625-
626-
/// Unilaterally updates the expression to have the associated value.
627-
/// That is, unlike `setExprValue` and `clearExprValue`, this method
628-
/// does not perform any checks on whether the expression had a
629-
/// previously associated value nor whether the new value is defined.
630-
//
631-
// TODO: The unilateral update semantics are required by the
632-
// current implementation of `CodegenEnv::genLoopBoundary`; however,
633-
// that implementation seems a bit dubious. We would much rather have
634-
// the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
635-
// `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
636-
// provide better invariants.
637-
void updateExprValue(ExprId e, Value v) {
638-
assert(isValidExprId(e));
639-
tensorExps[e].val = v;
598+
assert(exp(e).val && "Expression does not have an associated value");
599+
tensorExps[e].val = Value();
640600
}
641601

642602
#ifndef NDEBUG
@@ -706,12 +666,10 @@ class Merger {
706666
// `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
707667
// does not.
708668

709-
/// Map that converts pair<TensorId, LoopId> to the corresponding
710-
/// level-type.
669+
/// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
711670
std::vector<std::vector<DimLevelType>> lvlTypes;
712671

713-
/// Map that converts pair<TensorId, LoopId> to the corresponding
714-
/// level.
672+
/// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
715673
std::vector<std::vector<std::optional<Level>>> loopToLvl;
716674

717675
/// Map that converts pair<TensorId, Level> to the corresponding LoopId.

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,6 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
137137
auto r = callback(params); // may update parameters
138138
unsigned i = 0;
139139
if (isReduc()) {
140-
// FIXME: This requires `updateExprValue` to perform updates without
141-
// checking for a previous value; but it's not clear whether that's
142-
// by design or might be a potential source for bugs.
143140
updateReduc(params[i++]);
144141
if (redValidLexInsert)
145142
setValidLexInsert(params[i++]);
@@ -283,16 +280,15 @@ void CodegenEnv::endExpand() {
283280
void CodegenEnv::startReduc(ExprId exp, Value val) {
284281
assert(!isReduc() && exp != detail::kInvalidId);
285282
redExp = exp;
286-
updateReduc(val);
283+
redVal = val;
284+
latticeMerger.setExprValue(exp, val);
287285
}
288286

289287
void CodegenEnv::updateReduc(Value val) {
290288
assert(isReduc());
291289
redVal = val;
292-
// NOTE: `genLoopBoundary` requires that this performs a unilateral
293-
// update without checking for a previous value first. (It's not
294-
// clear whether any other callsites also require that.)
295-
latticeMerger.updateExprValue(redExp, val);
290+
latticeMerger.clearExprValue(redExp);
291+
latticeMerger.setExprValue(redExp, val);
296292
}
297293

298294
Value CodegenEnv::endReduc() {

0 commit comments

Comments
 (0)