Skip to content

Commit 930916c

Browse files
author
Mogball
committed
[MLIR][PDL] Add PDLL support for negated native constraints
This commit enables the expression of negated native constraints in PDLL: If a constraint is prefixed with "not" it is parsed as a negated constraint and hence the attribute `isNegated` of the emitted `pdl.apply_native_constraint` operation is set to `true`. In first instance this is only supported for the calling of external native C++ constraints and generation of PDL patterns. Previously, negating a native constraint would have been handled by creating an additional native call, e.g. ```PDLL Constraint checkA(input: Attr); Constarint checkNotA(input: Attr); ``` or by including an explicit additional operand for negation, e.g. `Constraint checkA(input: Attr, negated: Attr);` With this a constraint can simply be negated by prefixing it with `not`. e.g. ```PDLL Constraint simpleConstraint(op: Op); Pattern example { let inputOp = op<test.bar>() ->(type: Type); let root = op<test.foo>(inputOp.0) -> (); not simpleConstraint(inputOp); simpleConstraint(root); erase root; } ``` Depends on [[ https://reviews.llvm.org/D153871 | D153871 ]] Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D153959
1 parent 2d696d0 commit 930916c

File tree

10 files changed

+120
-25
lines changed

10 files changed

+120
-25
lines changed

mlir/include/mlir/Tools/PDLL/AST/Nodes.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
390390
private llvm::TrailingObjects<CallExpr, Expr *> {
391391
public:
392392
static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
393-
ArrayRef<Expr *> arguments, Type resultType);
393+
ArrayRef<Expr *> arguments, Type resultType,
394+
bool isNegated = false);
394395

395396
/// Return the callable of this call.
396397
Expr *getCallableExpr() const { return callable; }
@@ -403,9 +404,14 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
403404
return const_cast<CallExpr *>(this)->getArguments();
404405
}
405406

407+
/// Returns whether the result of this call is to be negated.
408+
bool getIsNegated() const { return isNegated; }
409+
406410
private:
407-
CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
408-
: Base(loc, type), callable(callable), numArgs(numArgs) {}
411+
CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
412+
bool isNegated)
413+
: Base(loc, type), callable(callable), numArgs(numArgs),
414+
isNegated(isNegated) {}
409415

410416
/// The callable of this call.
411417
Expr *callable;
@@ -415,6 +421,9 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
415421

416422
/// TrailingObject utilities.
417423
friend llvm::TrailingObjects<CallExpr, Expr *>;
424+
425+
// Is the result of this call to be negated.
426+
bool isNegated;
418427
};
419428

420429
//===----------------------------------------------------------------------===//

mlir/lib/Tools/PDLL/AST/NodePrinter.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,10 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
225225
void NodePrinter::printImpl(const CallExpr *expr) {
226226
os << "CallExpr " << expr << " Type<";
227227
print(expr->getType());
228-
os << ">\n";
228+
os << ">";
229+
if (expr->getIsNegated())
230+
os << " Negated";
231+
os << "\n";
229232
printChildren(expr->getCallableExpr());
230233
printChildren("Arguments", expr->getArguments());
231234
}

mlir/lib/Tools/PDLL/AST/Nodes.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,13 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
266266
//===----------------------------------------------------------------------===//
267267

268268
CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
269-
ArrayRef<Expr *> arguments, Type resultType) {
269+
ArrayRef<Expr *> arguments, Type resultType,
270+
bool isNegated) {
270271
unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
271272
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
272273

273-
CallExpr *expr =
274-
new (rawData) CallExpr(loc, resultType, callable, arguments.size());
274+
CallExpr *expr = new (rawData)
275+
CallExpr(loc, resultType, callable, arguments.size(), isNegated);
275276
std::uninitialized_copy(arguments.begin(), arguments.end(),
276277
expr->getArguments().begin());
277278
return expr;

mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,14 @@ class CodeGen {
103103
Value genExprImpl(const ast::TypeExpr *expr);
104104

105105
SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
106-
Location loc, ValueRange inputs);
106+
Location loc, ValueRange inputs,
107+
bool isNegated = false);
107108
SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
108109
Location loc, ValueRange inputs);
109110
template <typename PDLOpT, typename T>
110111
SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
111-
ValueRange inputs);
112+
ValueRange inputs,
113+
bool isNegated = false);
112114

113115
//===--------------------------------------------------------------------===//
114116
// Fields
@@ -419,7 +421,7 @@ SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
419421
// Generate the PDL based on the type of callable.
420422
const ast::Decl *callable = callableExpr->getDecl();
421423
if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
422-
return genConstraintCall(decl, loc, arguments);
424+
return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
423425
if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
424426
return genRewriteCall(decl, loc, arguments);
425427
llvm_unreachable("unhandled CallExpr callable");
@@ -547,15 +549,15 @@ Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
547549

548550
SmallVector<Value>
549551
CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
550-
ValueRange inputs) {
552+
ValueRange inputs, bool isNegated) {
551553
// Apply any constraints defined on the arguments to the input values.
552554
for (auto it : llvm::zip(decl->getInputs(), inputs))
553555
applyVarConstraints(std::get<0>(it), std::get<1>(it));
554556

555557
// Generate the constraint call.
556558
SmallVector<Value> results =
557-
genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
558-
inputs);
559+
genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
560+
decl, loc, inputs, isNegated);
559561

560562
// Apply any constraints defined on the results of the constraint.
561563
for (auto it : llvm::zip(decl->getResults(), results))
@@ -570,9 +572,9 @@ SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
570572
}
571573

572574
template <typename PDLOpT, typename T>
573-
SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
574-
Location loc,
575-
ValueRange inputs) {
575+
SmallVector<Value>
576+
CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
577+
ValueRange inputs, bool isNegated) {
576578
const ast::CompoundStmt *cstBody = decl->getBody();
577579

578580
// If the decl doesn't have a statement body, it is a native decl.
@@ -585,8 +587,10 @@ SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
585587
} else {
586588
resultTypes.push_back(genType(declResultType));
587589
}
588-
Operation *pdlOp = builder.create<PDLOpT>(
590+
PDLOpT pdlOp = builder.create<PDLOpT>(
589591
loc, resultTypes, decl->getName().getName(), inputs);
592+
if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593+
cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
590594
return pdlOp->getResults();
591595
}
592596

mlir/lib/Tools/PDLL/Parser/Lexer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
315315
.Case("erase", Token::kw_erase)
316316
.Case("let", Token::kw_let)
317317
.Case("Constraint", Token::kw_Constraint)
318+
.Case("not", Token::kw_not)
318319
.Case("op", Token::kw_op)
319320
.Case("Op", Token::kw_Op)
320321
.Case("OpName", Token::kw_OpName)

mlir/lib/Tools/PDLL/Parser/Lexer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Token {
5757
kw_erase,
5858
kw_let,
5959
kw_Constraint,
60+
kw_not,
6061
kw_Op,
6162
kw_OpName,
6263
kw_Pattern,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,14 @@ class Parser {
315315

316316
/// Identifier expressions.
317317
FailureOr<ast::Expr *> parseAttributeExpr();
318-
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
318+
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
319+
bool isNegated = false);
319320
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
320321
FailureOr<ast::Expr *> parseIdentifierExpr();
321322
FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322323
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
323324
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
325+
FailureOr<ast::Expr *> parseNegatedExpr();
324326
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325327
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326328
FailureOr<ast::Expr *>
@@ -405,7 +407,8 @@ class Parser {
405407

406408
FailureOr<ast::CallExpr *>
407409
createCallExpr(SMRange loc, ast::Expr *parentExpr,
408-
MutableArrayRef<ast::Expr *> arguments);
410+
MutableArrayRef<ast::Expr *> arguments,
411+
bool isNegated = false);
409412
FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
410413
FailureOr<ast::DeclRefExpr *>
411414
createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
@@ -1805,6 +1808,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
18051808
case Token::kw_Constraint:
18061809
lhsExpr = parseInlineConstraintLambdaExpr();
18071810
break;
1811+
case Token::kw_not:
1812+
lhsExpr = parseNegatedExpr();
1813+
break;
18081814
case Token::identifier:
18091815
lhsExpr = parseIdentifierExpr();
18101816
break;
@@ -1866,7 +1872,8 @@ FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
18661872
return ast::AttributeExpr::create(ctx, loc, attrExpr);
18671873
}
18681874

1869-
FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1875+
FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1876+
bool isNegated) {
18701877
consumeToken(Token::l_paren);
18711878

18721879
// Parse the arguments of the call.
@@ -1890,7 +1897,7 @@ FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
18901897
if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
18911898
return failure();
18921899

1893-
return createCallExpr(loc, parentExpr, arguments);
1900+
return createCallExpr(loc, parentExpr, arguments, isNegated);
18941901
}
18951902

18961903
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
@@ -1959,6 +1966,17 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
19591966
return createMemberAccessExpr(parentExpr, memberName, loc);
19601967
}
19611968

1969+
FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1970+
consumeToken(Token::kw_not);
1971+
// Only native constraints are supported after negation
1972+
if (!curToken.is(Token::identifier))
1973+
return emitError("expected native constraint");
1974+
FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1975+
if (failed(identifierExpr))
1976+
return failure();
1977+
return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1978+
}
1979+
19621980
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
19631981
SMRange loc = curToken.getLoc();
19641982

@@ -2672,7 +2690,7 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
26722690

26732691
FailureOr<ast::CallExpr *>
26742692
Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2675-
MutableArrayRef<ast::Expr *> arguments) {
2693+
MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
26762694
ast::Type parentType = parentExpr->getType();
26772695

26782696
ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
@@ -2686,8 +2704,14 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
26862704
if (isa<ast::UserConstraintDecl>(callableDecl))
26872705
return emitError(
26882706
loc, "unable to invoke `Constraint` within a rewrite section");
2689-
} else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2690-
return emitError(loc, "unable to invoke `Rewrite` within a match section");
2707+
if (isNegated)
2708+
return emitError(loc, "unable to negate a Rewrite");
2709+
} else {
2710+
if (isa<ast::UserRewriteDecl>(callableDecl))
2711+
return emitError(loc,
2712+
"unable to invoke `Rewrite` within a match section");
2713+
if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2714+
return emitError(loc, "unable to negate non native constraints");
26912715
}
26922716

26932717
// Verify the arguments of the call.
@@ -2718,7 +2742,7 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
27182742
}
27192743

27202744
return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2721-
callableDecl->getResultType());
2745+
callableDecl->getResultType(), isNegated);
27222746
}
27232747

27242748
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,

mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ Pattern TestExternalCall => replace root: Op with TestRewrite(root);
3636

3737
// -----
3838

39+
// CHECK: pdl.pattern @TestExternalNegatedCall
40+
// CHECK: %[[ROOT:.*]] = operation
41+
// CHECK: apply_native_constraint "TestConstraint"(%[[ROOT]] : !pdl.operation) {isNegated = true}
42+
// CHECK: rewrite %[[ROOT]]
43+
// CHECK: erase %[[ROOT]]
44+
Constraint TestConstraint(op: Op);
45+
Pattern TestExternalNegatedCall {
46+
let root = op : Op;
47+
not TestConstraint(root);
48+
erase root;
49+
}
50+
51+
// -----
52+
3953
//===----------------------------------------------------------------------===//
4054
// MemberAccessExpr
4155
//===----------------------------------------------------------------------===//

mlir/test/mlir-pdll/Parser/expr-failure.pdll

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ Pattern {
173173

174174
// -----
175175

176+
Constraint Foo(op: Op) {}
177+
178+
Pattern {
179+
// CHECK: unable to negate non native constraints
180+
let root = op<>;
181+
not Foo(root);
182+
}
183+
184+
// -----
185+
176186
Rewrite Foo();
177187

178188
Pattern {
@@ -183,6 +193,18 @@ Pattern {
183193

184194
// -----
185195

196+
Rewrite Foo(op: Op);
197+
198+
Pattern {
199+
// CHECK: unable to negate a Rewrite
200+
let root = op<>;
201+
rewrite root with {
202+
not Foo(root);
203+
}
204+
}
205+
206+
// -----
207+
186208
Pattern {
187209
// CHECK: expected expression
188210
let tuple = (10 = _: Value);

mlir/test/mlir-pdll/Parser/expr.pdll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ Pattern {
5050

5151
// -----
5252

53+
// CHECK: Module {{.*}}
54+
// CHECK: -UserConstraintDecl {{.*}} Name<TestConstraint> ResultType<Tuple<>>
55+
// CHECK: `-PatternDecl {{.*}}
56+
// CHECK: -CallExpr {{.*}} Type<Tuple<>> Negated
57+
// CHECK: `-DeclRefExpr {{.*}} Type<Constraint>
58+
// CHECK: `-UserConstraintDecl {{.*}} Name<TestConstraint> ResultType<Tuple<>>
59+
Constraint TestConstraint(op: Op);
60+
61+
Pattern {
62+
let inputOp = op<my_dialect.bar>;
63+
not TestConstraint(inputOp);
64+
erase inputOp;
65+
}
66+
67+
// -----
68+
5369
//===----------------------------------------------------------------------===//
5470
// MemberAccessExpr
5571
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)