Skip to content

Commit 25f19c2

Browse files
committed
[mlir][smt] upstream SMT dialect
1 parent ff2ed15 commit 25f19c2

32 files changed

+3370
-0
lines changed

mlir/include/mlir/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ add_subdirectory(Ptr)
3333
add_subdirectory(Quant)
3434
add_subdirectory(SCF)
3535
add_subdirectory(Shape)
36+
add_subdirectory(SMT)
3637
add_subdirectory(SparseTensor)
3738
add_subdirectory(SPIRV)
3839
add_subdirectory(Tensor)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_dialect(SMT smt)
2+
add_mlir_doc(SMT SMT Dialects/SMTOps -gen-op-doc)
3+
# TODO(maX)
4+
#add_mlir_doc(SMT SMT Dialects/SMTTypes -gen-typedef-doc -dialect smt)
5+
6+
set(LLVM_TARGET_DEFINITIONS SMT.td)
7+
8+
mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls)
9+
mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs)
10+
add_public_tablegen_target(MLIRSMTAttrIncGen)
11+
add_dependencies(mlir-headers MLIRSMTAttrIncGen)
12+
13+
mlir_tablegen(SMTEnums.h.inc -gen-enum-decls)
14+
mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs)
15+
add_public_tablegen_target(MLIRSMTEnumsIncGen)
16+
add_dependencies(mlir-headers MLIRSMTEnumsIncGen)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- SMT.td - SMT dialect definition ---------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_SMT_SMT_TD
10+
#define MLIR_DIALECT_SMT_SMT_TD
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
include "mlir/Dialect/SMT/IR/SMTAttributes.td"
15+
include "mlir/Dialect/SMT/IR/SMTDialect.td"
16+
include "mlir/Dialect/SMT/IR/SMTTypes.td"
17+
include "mlir/Dialect/SMT/IR/SMTOps.td"
18+
include "mlir/Dialect/SMT/IR/SMTArrayOps.td"
19+
include "mlir/Dialect/SMT/IR/SMTBitVectorOps.td"
20+
include "mlir/Dialect/SMT/IR/SMTIntOps.td"
21+
22+
#endif // MLIR_DIALECT_SMT_SMT_TD
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===- SMTArrayOps.td - SMT array operations ---------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_SMT_SMTARRAYOPS_TD
10+
#define MLIR_DIALECT_SMT_SMTARRAYOPS_TD
11+
12+
include "mlir/Dialect/SMT/IR/SMTDialect.td"
13+
include "mlir/Dialect/SMT/IR/SMTAttributes.td"
14+
include "mlir/Dialect/SMT/IR/SMTTypes.td"
15+
include "mlir/Interfaces/SideEffectInterfaces.td"
16+
17+
class SMTArrayOp<string mnemonic, list<Trait> traits = []> :
18+
SMTOp<"array." # mnemonic, traits>;
19+
20+
def ArrayStoreOp : SMTArrayOp<"store", [
21+
Pure,
22+
TypesMatchWith<"summary", "array", "index",
23+
"cast<ArrayType>($_self).getDomainType()">,
24+
TypesMatchWith<"summary", "array", "value",
25+
"cast<ArrayType>($_self).getRangeType()">,
26+
AllTypesMatch<["array", "result"]>,
27+
]> {
28+
let summary = "stores a value at a given index and returns the new array";
29+
let description = [{
30+
This operation returns a new array which is the same as the 'array' operand
31+
except that the value at the given 'index' is changed to the given 'value'.
32+
The semantics are equivalent to the 'store' operator described in the
33+
[SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of
34+
the SMT-LIB standard 2.6.
35+
}];
36+
37+
let arguments = (ins ArrayType:$array, AnySMTType:$index, AnySMTType:$value);
38+
let results = (outs ArrayType:$result);
39+
40+
let assemblyFormat = [{
41+
$array `[` $index `]` `,` $value attr-dict `:` qualified(type($array))
42+
}];
43+
}
44+
45+
def ArraySelectOp : SMTArrayOp<"select", [
46+
Pure,
47+
TypesMatchWith<"summary", "array", "index",
48+
"cast<ArrayType>($_self).getDomainType()">,
49+
TypesMatchWith<"summary", "array", "result",
50+
"cast<ArrayType>($_self).getRangeType()">,
51+
]> {
52+
let summary = "get the value stored in the array at the given index";
53+
let description = [{
54+
This operation is retuns the value stored in the given array at the given
55+
index. The semantics are equivalent to the `select` operator defined in the
56+
[SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of
57+
the SMT-LIB standard 2.6.
58+
}];
59+
60+
let arguments = (ins ArrayType:$array, AnySMTType:$index);
61+
let results = (outs AnySMTType:$result);
62+
63+
let assemblyFormat = [{
64+
$array `[` $index `]` attr-dict `:` qualified(type($array))
65+
}];
66+
}
67+
68+
def ArrayBroadcastOp : SMTArrayOp<"broadcast", [
69+
Pure,
70+
TypesMatchWith<"summary", "result", "value",
71+
"cast<ArrayType>($_self).getRangeType()">,
72+
]> {
73+
let summary = "construct an array with the given value stored at every index";
74+
let description = [{
75+
This operation represents a broadcast of the 'value' operand to all indices
76+
of the array. It is equivalent to
77+
```
78+
%0 = smt.declare "array" : !smt.array<[!smt.int -> !smt.bool]>
79+
%1 = smt.forall ["idx"] {
80+
^bb0(%idx: !smt.int):
81+
%2 = smt.array.select %0[%idx] : !smt.array<[!smt.int -> !smt.bool]>
82+
%3 = smt.eq %value, %2 : !smt.bool
83+
smt.yield %3 : !smt.bool
84+
}
85+
smt.assert %1
86+
// return %0
87+
```
88+
89+
In SMT-LIB, this is frequently written as
90+
`((as const (Array Int Bool)) value)`.
91+
}];
92+
93+
let arguments = (ins AnySMTType:$value);
94+
let results = (outs ArrayType:$result);
95+
96+
let assemblyFormat = "$value attr-dict `:` qualified(type($result))";
97+
}
98+
99+
#endif // MLIR_DIALECT_SMT_SMTARRAYOPS_TD
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_SMT_SMTATTRIBUTES_H
10+
#define MLIR_DIALECT_SMT_SMTATTRIBUTES_H
11+
12+
#include "mlir/IR/Attributes.h"
13+
#include "mlir/IR/BuiltinAttributeInterfaces.h"
14+
#include "mlir/IR/BuiltinAttributes.h"
15+
16+
namespace mlir {
17+
namespace smt {
18+
namespace detail {
19+
20+
struct BitVectorAttrStorage;
21+
22+
} // namespace detail
23+
} // namespace smt
24+
} // namespace mlir
25+
26+
#define GET_ATTRDEF_CLASSES
27+
#include "mlir/Dialect/SMT/IR/SMTAttributes.h.inc"
28+
29+
#endif // MLIR_DIALECT_SMT_SMTATTRIBUTES_H
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//===- SMTAttributes.td - Attributes for SMT dialect -------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines SMT dialect specific attributes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_SMT_SMTATTRIBUTES_TD
14+
#define MLIR_DIALECT_SMT_SMTATTRIBUTES_TD
15+
16+
include "mlir/Dialect/SMT/IR/SMTDialect.td"
17+
include "mlir/IR/EnumAttr.td"
18+
include "mlir/IR/BuiltinAttributeInterfaces.td"
19+
20+
def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
21+
DeclareAttrInterfaceMethods<TypedAttrInterface>
22+
]> {
23+
let mnemonic = "bv";
24+
let description = [{
25+
This attribute represents a constant value of the `(_ BitVec width)` sort as
26+
described in the [SMT bit-vector
27+
theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml).
28+
29+
The constant is as #bX (binary) or #xX (hexadecimal) in SMT-LIB
30+
where X is the value in the corresponding format without any further
31+
prefixing. Here, the bit-vector constant is given as a regular integer
32+
literal and the associated bit-vector type indicating the bit-width.
33+
34+
Examples:
35+
```mlir
36+
#smt.bv<5> : !smt.bv<4>
37+
#smt.bv<92> : !smt.bv<8>
38+
```
39+
40+
The explicit type-suffix is mandatory to uniquely represent the attribute,
41+
i.e., this attribute should always be used in the extended form (using the
42+
`quantified` keyword in the operation assembly format string).
43+
44+
The bit-width must be greater than zero (i.e., at least one digit has to be
45+
present).
46+
}];
47+
48+
let parameters = (ins "llvm::APInt":$value);
49+
50+
let hasCustomAssemblyFormat = true;
51+
let genVerifyDecl = true;
52+
53+
// We need to manually define the storage class because the generated one is
54+
// buggy (because the APInt asserts matching bitwidth in the `==` operator and
55+
// the generated storage uses that directly.
56+
// Alternatively: add a type parameter to redundantly store the bitwidth of
57+
// of the attribute type, it it's in the order before the 'value' it will be
58+
// checked before the APInt equality (this is the reason it works for the
59+
// builtin integer attribute), but would be more fragile (and we'd store
60+
// duplicate data).
61+
let genStorageClass = false;
62+
63+
let builders = [
64+
AttrBuilder<(ins "llvm::StringRef":$value)>,
65+
AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>,
66+
];
67+
68+
let extraClassDeclaration = [{
69+
/// Return the bit-vector constant as a SMT-LIB formatted string.
70+
std::string getValueAsString(bool prefix = true) const;
71+
}];
72+
}
73+
74+
#endif // MLIR_DIALECT_SMT_SMTATTRIBUTES_TD

0 commit comments

Comments
 (0)