[mlir][dataflow] disallow outside use of propagateIfChanged for DataFlowSolver #120885

ZenithalHourlyRate

A detailed writeup is in google/heir#1153; see also #120881. In short, because propagateIfChanged is public, it gets called outside the scope of a DataFlowAnalysis. Such a call does not propagate anything, because by then the DataFlowSolver has stopped running and nothing processes the enqueued work items.

To prevent this misuse, propagateIfChanged should be made protected/private.

Downstream users affected by this should re-run the analysis (see #120881) to propagate the change correctly, rather than just calling propagateIfChanged.

The change to IntegerRangeAnalysis is just an expansion of solver->propagateIfChanged: the lattice has already been updated by the join, and propagation is done by onUpdate.
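
As a rough illustration of the failure mode described above, the following self-contained C++ sketch models a worklist-driven solver. `ToySolver`, `WorkItem`, `pending`, and `visitsAfterLateUpdate` are hypothetical names invented for this example and are not MLIR's DataFlowSolver API. The sketch only assumes that the solver drains a queue while it runs and that propagateIfChanged merely enqueues work.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <utility>

// Hypothetical toy solver, NOT the MLIR DataFlowSolver.
class ToySolver {
public:
  using WorkItem = std::function<void()>;

  // Mirrors the idea of propagateIfChanged: on a change, push the dependent
  // work item to the back of the queue. Nothing is executed here.
  void propagateIfChanged(bool changed, WorkItem dependent) {
    if (changed)
      worklist.push_back(std::move(dependent));
  }

  // Drain the queue until it is empty, then stop.
  void run() {
    while (!worklist.empty()) {
      WorkItem item = std::move(worklist.front());
      worklist.pop_front();
      item();
    }
  }

  std::size_t pending() const { return worklist.size(); }

private:
  std::deque<WorkItem> worklist;
};

// Returns how often the dependent was visited when a second update arrives
// after the solver has already finished running.
inline int visitsAfterLateUpdate() {
  ToySolver solver;
  int visits = 0;
  solver.propagateIfChanged(true, [&visits] { ++visits; });
  solver.run(); // processes the first work item
  assert(visits == 1);
  // The solver has stopped: this call only enqueues, nothing processes it.
  solver.propagateIfChanged(true, [&visits] { ++visits; });
  return visits;
}
```

In this toy, the late update stays in the queue until `run()` is called again, which corresponds to the advice to re-run the analysis.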

Cc @Mogball for review

@llvmbot llvmbot added the mlir label Dec 22, 2024

llvmbot commented Dec 22, 2024

@llvm/pr-subscribers-mlir

Author: Hongren Zheng (ZenithalHourlyRate)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/120885.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Analysis/DataFlowFramework.h (+7-3)
  • (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+12-5)
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 969664dc7a4fe3..1dcb32f762c003 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -394,13 +394,17 @@ class DataFlowSolver {
   template <typename StateT, typename AnchorT>
   StateT *getOrCreateState(AnchorT anchor);
 
+  /// Get the configuration of the solver.
+  const DataFlowConfig &getConfig() const { return config; }
+
+protected:
   /// Propagate an update to an analysis state if it changed by pushing
   /// dependent work items to the back of the queue.
+  /// This should only be used by DataFlowAnalysis instances.
+  /// When used outside of DataFlowAnalysis, the solver won't process
+  /// the work items
   void propagateIfChanged(AnalysisState *state, ChangeResult changed);
 
-  /// Get the configuration of the solver.
-  const DataFlowConfig &getConfig() const { return config; }
-
 private:
   /// Configuration of the dataflow solver.
   DataFlowConfig config;
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9e9411e5ede12c..60ae7d00c0bbdb 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -45,9 +45,13 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   std::optional<APInt> constant = getValue().getValue().getConstantValue();
   auto value = cast<Value>(anchor);
   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
-  if (!constant)
-    return solver->propagateIfChanged(
-        cv, cv->join(ConstantValue::getUnknownConstant()));
+  if (!constant) {
+    auto changed = cv->join(ConstantValue::getUnknownConstant());
+    if (changed == ChangeResult::Change) {
+      cv->onUpdate(solver);
+    }
+    return;
+  }
 
   Dialect *dialect;
   if (auto *parent = value.getDefiningOp())
@@ -56,8 +60,11 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
     dialect = value.getParentBlock()->getParentOp()->getDialect();
 
   Type type = getElementTypeOrSelf(value);
-  solver->propagateIfChanged(
-      cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+  auto changed =
+      cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect));
+  if (changed == ChangeResult::Change) {
+    cv->onUpdate(solver);
+  }
 }
 
 LogicalResult IntegerRangeAnalysis::visitOperation(

@Mogball Mogball left a comment

Thanks for the PR and I agree that this being a public method is a bit of a footgun. However, directly invoking this method inside other dataflow variables is the intended usage of the method. Do you have a better idea on how an API can be exposed that way? propagateIfChanged is supposed to be a shorthand so that users don't forget to check the result of join, for example. Maybe it should be marked as nodiscard? (Unless it already is...). Alternatively, placing an assert inside propagateIfChanged checking that the solver is running might do the trick.
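
Both suggestions in this comment can be sketched in isolation. The toy below marks a hypothetical `ChangeResult` enum as `[[nodiscard]]` and puts an assert on a hypothetical `isRunning` flag inside `propagateIfChanged`. `MaxLattice` and `ToySolver` are invented for illustration; nothing here is taken from the MLIR sources, and whether MLIR's own `ChangeResult` is already nodiscard is left open, as in the comment.

```cpp
#include <cassert>

// Hypothetical stand-ins; the names echo MLIR's but this is not MLIR code.
// Marking the enum [[nodiscard]] makes compilers warn when a function
// returning it has its result ignored.
enum class [[nodiscard]] ChangeResult { NoChange, Change };

// Toy lattice over ints whose join is max.
struct MaxLattice {
  int value = 0;
  ChangeResult join(int other) {
    if (other <= value)
      return ChangeResult::NoChange;
    value = other;
    return ChangeResult::Change;
  }
};

struct ToySolver {
  bool isRunning = false; // assumed flag, set by whoever drives the solver
  int enqueued = 0;

  void propagateIfChanged(MaxLattice *state, ChangeResult changed) {
    // Second suggestion: catch calls made while the solver is not running.
    assert(isRunning && "propagateIfChanged called outside of a solver run");
    (void)state; // the toy does not track per-state dependents
    if (changed == ChangeResult::Change)
      ++enqueued;
  }
};
```

With this in place, a bare `lattice.join(3);` statement would typically draw a discarded-result warning, while `solver.propagateIfChanged(&lattice, lattice.join(3));` consumes the result as intended.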

-  if (!constant)
-    return solver->propagateIfChanged(
-        cv, cv->join(ConstantValue::getUnknownConstant()));
+  if (!constant) {

This is actually the intended usage of the propagateIfChanged method. Data flow variables that depend on each other can issue updates to each other through this method.
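
A minimal sketch of that pattern, under simplifying assumptions: one state's `onUpdate` joins a value into a second state and forwards the result to `propagateIfChanged`. All types here (`ToyState`, `ToySolver`, `ConstState`, `RangeState`) are hypothetical toys, dependents are plain integer ids, and the range state holds a direct pointer to the constant state instead of looking it up through the solver.

```cpp
#include <cassert>
#include <vector>

enum class ChangeResult { NoChange, Change };

struct ToySolver;

// Hypothetical base state: onUpdate enqueues the ids of dependent work items.
struct ToyState {
  std::vector<int> dependents;
  virtual ~ToyState() = default;
  virtual void onUpdate(ToySolver &solver) const;
};

// Hypothetical solver: propagateIfChanged is shorthand for
// "if the join changed something, call onUpdate on that state".
struct ToySolver {
  std::vector<int> queue;
  void propagateIfChanged(const ToyState &state, ChangeResult changed) {
    if (changed == ChangeResult::Change)
      state.onUpdate(*this);
  }
};

void ToyState::onUpdate(ToySolver &solver) const {
  for (int id : dependents)
    solver.queue.push_back(id);
}

// Three-level constant lattice: uninitialized < constant < unknown.
struct ConstState : ToyState {
  enum class Kind { Uninitialized, Constant, Unknown };
  Kind kind = Kind::Uninitialized;
  int constant = 0;

  ChangeResult join(int c) {
    if (kind == Kind::Unknown || (kind == Kind::Constant && constant == c))
      return ChangeResult::NoChange;
    if (kind == Kind::Uninitialized) {
      kind = Kind::Constant;
      constant = c;
    } else {
      kind = Kind::Unknown; // two different constants meet
    }
    return ChangeResult::Change;
  }
};

// A range that collapses to a single value updates the constant state.
struct RangeState : ToyState {
  int lo = 0, hi = 0;
  ConstState *constantState = nullptr;

  void onUpdate(ToySolver &solver) const override {
    ToyState::onUpdate(solver); // enqueue this state's own dependents
    if (constantState && lo == hi)
      solver.propagateIfChanged(*constantState, constantState->join(lo));
  }
};
```

In the toy, the constant state's dependents are enqueued only when the join actually changes it, which is the check the shorthand is meant to keep users from forgetting.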

@ZenithalHourlyRate ZenithalHourlyRate force-pushed the mlir-dataflow-solver-propagate branch from 66f5e03 to 2dc277b Compare December 27, 2024 06:04
@ZenithalHourlyRate

Alternatively, placing an assert inside propagateIfChanged checking that the solver is running might do the trick.

Added an isRunning flag. It works, but the flag must be managed carefully (it has to be unset on every return path).
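
One way to avoid unsetting the flag by hand on each return is a small scope guard. This is only a sketch of the general RAII technique, not what the PR does: `ToySolver`, `RunningGuard`, `run`, and `propagate` are hypothetical names, and the early return on a negative item stands in for any failure path.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy; not the MLIR DataFlowSolver.
class ToySolver {
public:
  // Processes items in order; returns false at the first negative item.
  bool run(const std::vector<int> &items) {
    RunningGuard guard(isRunning); // set now, cleared on every return path
    for (int item : items) {
      if (item < 0)
        return false; // early return: the guard still clears the flag
      propagate(item);
    }
    return true;
  }

  void propagate(int item) {
    assert(isRunning && "propagate called while the solver is not running");
    processed.push_back(item);
  }

  bool running() const { return isRunning; }

  std::vector<int> processed;

private:
  // Sets the flag on construction and clears it on destruction.
  struct RunningGuard {
    explicit RunningGuard(bool &flag) : flagRef(flag) { flagRef = true; }
    ~RunningGuard() { flagRef = false; }
    RunningGuard(const RunningGuard &) = delete;
    RunningGuard &operator=(const RunningGuard &) = delete;
    bool &flagRef;
  };

  bool isRunning = false;
};
```

The guard trades a few lines of boilerplate for not having to audit every return statement when the run loop changes later.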

Do you have a better idea on how an API can be exposed that way?

I was thinking of adding propagateIfChanged to AnalysisState so that all AnalysisState subclasses could access it.

class DataFlowSolver {
  friend class AnalysisState;
};

class AnalysisState {

  // existing API
  virtual void onUpdate(DataFlowSolver *solver) const;

  // NOTE: it is not virtual
  void propagateIfChanged(DataFlowSolver *solver, AnalysisState *state, ChangeResult changed) {
    solver->propagateIfChanged(state, changed);
  }

};
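
A compilable toy version of this proposal, for illustration only. The classes reuse the names from the snippet above but are simplified stand-ins, not the MLIR classes; `CounterState`, `bump`, and `numPropagated` are hypothetical additions. The point it demonstrates is that C++ friendship is not inherited, so a subclass has to go through the non-virtual forwarder in the base class.

```cpp
#include <cassert>

enum class ChangeResult { NoChange, Change };

class AnalysisState;

// Simplified stand-in for the solver: the propagation method is private.
class DataFlowSolver {
public:
  int numPropagated() const { return propagated; }

private:
  friend class AnalysisState; // only AnalysisState may call the method below

  void propagateIfChanged(AnalysisState *state, ChangeResult changed) {
    (void)state; // the toy only counts propagations
    if (changed == ChangeResult::Change)
      ++propagated;
  }

  int propagated = 0;
};

class AnalysisState {
public:
  virtual ~AnalysisState() = default;

protected:
  // Non-virtual forwarder. Friendship is not inherited, so subclasses cannot
  // call solver->propagateIfChanged directly and must use this helper.
  void propagateIfChanged(DataFlowSolver *solver, AnalysisState *state,
                          ChangeResult changed) {
    solver->propagateIfChanged(state, changed);
  }
};

// Hypothetical subclass that reports a change every time it is bumped.
class CounterState : public AnalysisState {
public:
  void bump(DataFlowSolver *solver) {
    ++count;
    propagateIfChanged(solver, this, ChangeResult::Change);
  }

  int count = 0;
};
```

Code outside the AnalysisState hierarchy, such as a pass that holds a solver after it has run, can no longer reach the method in this arrangement, which is the misuse the PR targets.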

@ZenithalHourlyRate

I wonder if there is further feedback on this PR.

@Mogball Mogball left a comment

Thanks for the ping. I think this looks alright. Just one nit

@ZenithalHourlyRate ZenithalHourlyRate force-pushed the mlir-dataflow-solver-propagate branch from 2dc277b to b51abb9 Compare January 28, 2025 03:24
@ZenithalHourlyRate

Comments addressed.

@ZenithalHourlyRate ZenithalHourlyRate merged commit 3a439e2 into llvm:main Jan 28, 2025
6 of 7 checks passed