Skip to content

[mlir][PDL] Add support for native constraints with results #82760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
let description = [{
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
has been registered externally with the consumer of PDL, to a given set of
entities.
entities and optionally return a number of values.

Example:

```mlir
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
```
}];

let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
}];
let hasVerifier = 1;
}

Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
values. On success, this operation branches to the true destination,
values.
The constraint function may return any number of results.
On success, this operation branches to the true destination,
otherwise the false destination is taken. This behavior can be reversed
by setting the attribute `isNegated` to true.

Expand All @@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` attr-dict `->` successors
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
`->` successors
}];
}

Expand Down
18 changes: 11 additions & 7 deletions mlir/include/mlir/IR/PDLPatternMatch.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,9 @@ protected:
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
using PDLConstraintFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;

/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
Expand Down Expand Up @@ -726,7 +727,7 @@ std::enable_if_t<
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
PatternRewriter &rewriter,
PatternRewriter &rewriter, PDLResultList &,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
Expand Down Expand Up @@ -842,10 +843,13 @@ public:
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
/// * `LogicalResult (PatternRewriter &,
/// PDLResultList &,
/// ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form.
/// the low-level PDLValue form, and the results are manually appended to
/// the given result list.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
Expand Down Expand Up @@ -960,8 +964,8 @@ public:
}
};
class PDLResultList {};
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
using PDLConstraintFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;

Expand Down
41 changes: 31 additions & 10 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ struct PatternLowering {

/// Generate interpreter operations for the tree rooted at the given matcher
/// node, in the specified region.
Block *generateMatcher(MatcherNode &node, Region &region);
Block *generateMatcher(MatcherNode &node, Region &region,
Block *block = nullptr);

/// Get or create an access to the provided positional value in the current
/// block. This operation may mutate the provided block pointer if nested
Expand Down Expand Up @@ -148,6 +149,10 @@ struct PatternLowering {
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;

/// A mapping from a constraint question to the ApplyConstraintOp
/// that implements it.
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
};
} // namespace

Expand Down Expand Up @@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
firstMatcherBlock->erase();
}

Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
Block *block) {
// Push a new scope for the values used by this matcher.
Block *block = &region.emplaceBlock();
if (!block)
block = &region.emplaceBlock();
ValueMapScope scope(values);

// If this is the return node, simply insert the corresponding interpreter
Expand Down Expand Up @@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
loc, cast<ArrayAttr>(rawTypeAttr));
break;
}
case Predicates::ConstraintResultPos: {
// Due to the order of traversal, the ApplyConstraintOp has already been
// created and we can find it in constraintOpMap.
auto *constrResPos = cast<ConstraintPosition>(pos);
auto i = constraintOpMap.find(constrResPos->getQuestion());
assert(i != constraintOpMap.end());
value = i->second->getResult(constrResPos->getIndex());
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
Expand All @@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
args.push_back(getValueAt(currentBlock, position));
}

// Generate the matcher in the current (potentially nested) region
// and get the failure successor.
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
// Generate a new block as success successor and get the failure successor.
Block *success = &region->emplaceBlock();
Block *failure = failureBlockStack.back();

// Finally, create the predicate.
// Create the predicate.
builder.setInsertionPointToEnd(currentBlock);
Predicates::Kind kind = question->getKind();
switch (kind) {
Expand Down Expand Up @@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
failure);
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
cstQuestion->getIsNegated(), success, failure);

constraintOpMap.insert({cstQuestion, applyConstraintOp});
break;
}
default:
llvm_unreachable("Generating unknown Predicate operation");
}

// Generate the matcher in the current (potentially nested) region.
// This might use the results of the current predicate.
generateMatcher(*boolNode->getSuccessNode(), *region, success);
}

template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
Expand Down
58 changes: 47 additions & 11 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ enum Kind : unsigned {
OperandPos,
OperandGroupPos,
AttributePos,
ConstraintResultPos,
ResultPos,
ResultGroupPos,
TypePos,
Expand Down Expand Up @@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
bool isOperandDefiningOp() const;
};

//===----------------------------------------------------------------------===//
// ConstraintPosition

struct ConstraintQuestion;

/// A position describing the result of a native constraint. It saves the
/// corresponding ConstraintQuestion and result index to enable referring
/// back to them
struct ConstraintPosition
: public PredicateBase<ConstraintPosition, Position,
std::pair<ConstraintQuestion *, unsigned>,
Predicates::ConstraintResultPos> {
using PredicateBase::PredicateBase;

/// Returns the ConstraintQuestion to enable keeping track of the native
/// constraint this position stems from.
ConstraintQuestion *getQuestion() const { return key.first; }

// Returns the result index of this position
unsigned getIndex() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// ResultPosition

Expand Down Expand Up @@ -447,11 +470,13 @@ struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};

/// Apply a parameterized constraint to multiple position values.
/// Apply a parameterized constraint to multiple position values and possibly
/// produce results.
struct ConstraintQuestion
: public PredicateBase<ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, bool>,
Predicates::ConstraintQuestion> {
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
Predicates::ConstraintQuestion> {
using Base::Base;

/// Return the name of the constraint.
Expand All @@ -460,15 +485,19 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }

/// Return the result types of the constraint.
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }

/// Return the negation status of the constraint.
bool getIsNegated() const { return std::get<2>(key); }
bool getIsNegated() const { return std::get<3>(key); }

/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
std::get<2>(key)});
alloc.copyInto(std::get<2>(key)),
std::get<3>(key)});
}

/// Returns a hash suitable for the given keytype.
Expand Down Expand Up @@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<ConstraintPosition>();
registerParametricStorageType<ForEachPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
Expand Down Expand Up @@ -588,6 +618,12 @@ class PredicateBuilder {
return OperationPosition::get(uniquer, p);
}

// Returns a position for a new value created by a constraint.
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
unsigned index) {
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
}

/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
Expand Down Expand Up @@ -673,11 +709,11 @@ class PredicateBuilder {
}

/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
bool isNegated) {
return {
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
TrueAnswer::get(uniquer)};
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
ArrayRef<Type> resultTypes, bool isNegated) {
return {ConstraintQuestion::get(
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
TrueAnswer::get(uniquer)};
}

/// Create a predicate comparing a value with null.
Expand Down
Loading