@@ -122,19 +122,10 @@ struct TensorExp final {
122
122
// /
123
123
// / The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
124
124
// / That is, its argument is a `LoopId` identifying the loop-variable
125
- // / in question, and its value will be the current iteration's value
126
- // / of that loop-variable. See the `LoopId` documentation for more details.
127
- // /
128
- // / The `kSynZero` leaf kind is for representing a synthetic zero value, which
129
- // / can be introduced when sparsifying operations like `arith::cmp` to generate
130
- // / `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
131
- //
132
- // TODO: Modify this definition so that the numeric values already encode
133
- // the `ExpArity` (while extending the notion of "arity" to include not
134
- // just the number of `ExprId` children the node has, but also whether the
135
- // node has a `Value` and/or `Operation*`). Doing this will avoid needing
136
- // to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
137
- // and should help clean up a few other places as well.
125
+ // / in question, and its value will be the current iteration's value.
126
+ // / The `kSynZero` leaf kind is for representing a synthetic zero value,
127
+ // / which can be introduced when sparsifying operations like `arith::cmp`
128
+ // / to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
138
129
enum class TensorExp ::Kind {
139
130
// Leaf.
140
131
kTensor = 0 ,
@@ -253,15 +244,6 @@ class Merger {
253
244
// /
254
245
// / The maxLvlRank specifies the max level rank of all inputs/output tensors.
255
246
// / It is used to pre-allocate sufficient memory for internal storage.
256
- //
257
- // TODO: we want to make the filter loop more efficient in the future,
258
- // e.g., by avoiding scanning the full list of stored coordinates (keeping
259
- // the last position in ordered list) or even apply binary search to find
260
- // the coordinate.
261
- //
262
- // TODO: would be cleaner to understand/document if the first argument
263
- // gave the number of input tensors, instead of the current number of
264
- // input+output tensors.
265
247
Merger (unsigned numInputOutputTensors, unsigned numNativeLoops,
266
248
unsigned numFilterLoops, unsigned maxLvlRank);
267
249
@@ -383,12 +365,15 @@ class Merger {
383
365
384
366
// / Gets the total number of loops (native loops + filter loops).
385
367
constexpr unsigned getNumLoops () const { return numLoops; }
368
+
386
369
// / Gets the number of native loops.
387
370
constexpr unsigned getNumNativeLoops () const { return numNativeLoops; }
371
+
388
372
// / Gets the number of filter loops.
389
373
constexpr unsigned getNumFilterLoops () const {
390
374
return numLoops - numNativeLoops;
391
375
}
376
+
392
377
// / Gets the identifier of the first filter-loop.
393
378
constexpr LoopId getStartingFilterLoopId () const {
394
379
return getNumNativeLoops ();
@@ -473,8 +458,7 @@ class Merger {
473
458
lvlTypes[t][i] = dlt;
474
459
loopToLvl[t][i] = lvl;
475
460
lvlToLoop[t][lvl] = i;
476
- // TODO: Maybe we should favor a constant loop bound when there are multiple
477
- // choices.
461
+ // TODO: favor a constant loop bound when there are multiple choices.
478
462
loopBounds[i] = std::make_pair (t, lvl);
479
463
}
480
464
@@ -600,43 +584,19 @@ class Merger {
600
584
// / Checks whether the given expression has an associated value.
601
585
bool hasExprValue (ExprId e) const { return static_cast <bool >(exp (e).val ); }
602
586
603
- // / Sets the expression to have the associated value. Asserts that
604
- // / the new value is defined, and that the expression does not already
605
- // / have a value. If you want to overwrite a previous associated value,
606
- // / use `updateExprValue` instead.
587
+ // / Sets the expression to have the associated value. Asserts that the new
588
+ // / value is defined, and that the expression does not already have a value.
607
589
void setExprValue (ExprId e, Value v) {
608
- assert (isValidExprId (e));
609
- assert (v && " Got an undefined value" );
610
- auto &val = tensorExps[e].val ;
611
- assert (!val && " Expression already has an associated value" );
612
- val = v;
590
+ assert (!exp (e).val && " Expression already has an associated value" );
591
+ assert (v && " Trying to assign an undefined value" );
592
+ tensorExps[e].val = v;
613
593
}
614
594
615
- // / Clears the value associated with the expression. Asserts that the
595
+ // / Clears the value associated with the expression. Asserts that the
616
596
// / expression does indeed have an associated value before clearing it.
617
- // / If you don't want to check for a previous associated value first,
618
- // / then use `updateExprValue` instead.
619
597
void clearExprValue (ExprId e) {
620
- assert (isValidExprId (e));
621
- auto &val = tensorExps[e].val ;
622
- assert (val && " Expression does not have an associated value to clear" );
623
- val = Value ();
624
- }
625
-
626
- // / Unilaterally updates the expression to have the associated value.
627
- // / That is, unlike `setExprValue` and `clearExprValue`, this method
628
- // / does not perform any checks on whether the expression had a
629
- // / previously associated value nor whether the new value is defined.
630
- //
631
- // TODO: The unilateral update semantics are required by the
632
- // current implementation of `CodegenEnv::genLoopBoundary`; however,
633
- // that implementation seems a bit dubious. We would much rather have
634
- // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
635
- // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
636
- // provide better invariants.
637
- void updateExprValue (ExprId e, Value v) {
638
- assert (isValidExprId (e));
639
- tensorExps[e].val = v;
598
+ assert (exp (e).val && " Expression does not have an associated value" );
599
+ tensorExps[e].val = Value ();
640
600
}
641
601
642
602
#ifndef NDEBUG
@@ -706,12 +666,10 @@ class Merger {
706
666
// `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
707
667
// does not.
708
668
709
- // / Map that converts pair<TensorId, LoopId> to the corresponding
710
- // / level-type.
669
+ // / Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
711
670
std::vector<std::vector<DimLevelType>> lvlTypes;
712
671
713
- // / Map that converts pair<TensorId, LoopId> to the corresponding
714
- // / level.
672
+ // / Map that converts pair<TensorId, LoopId> to the corresponding lvl.
715
673
std::vector<std::vector<std::optional<Level>>> loopToLvl;
716
674
717
675
// / Map that converts pair<TensorId, Level> to the corresponding LoopId.
0 commit comments