diff --git a/mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h b/mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h new file mode 100644 index 0000000000000..5bb76321199ee --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h @@ -0,0 +1,30 @@ +//===- ArithToSMT.h - Arith to SMT dialect conversion ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOSMT_H +#define MLIR_CONVERSION_ARITHTOSMT_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { + +class TypeConverter; +class RewritePatternSet; + +#define GEN_PASS_DECL_CONVERTARITHTOSMT +#include "mlir/Conversion/Passes.h.inc" + +namespace arith { +/// Get the Arith to SMT conversion patterns. +void populateArithToSMTConversionPatterns(TypeConverter &converter, + RewritePatternSet &patterns); +} // namespace arith +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOSMT_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index ccd862f67c068..b3a65d611a0d5 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -15,6 +15,7 @@ #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ArithToSMT/ArithToSMT.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..89d4c3c0b35b7 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1464,4 +1464,17 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { ]; } +//===----------------------------------------------------------------------===// +// ConvertArithToSMT +//===----------------------------------------------------------------------===// + +def ConvertArithToSMT : Pass<"convert-arith-to-smt"> { + let summary = "Convert arith ops and constants to SMT ops"; + let dependentDialects = [ + "smt::SMTDialect", + "arith::ArithDialect", + "mlir::func::FuncDialect" + ]; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index f710235197334..9d1a840d6644b 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -33,6 +33,7 @@ add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) +add_subdirectory(SMT) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Tensor) diff --git a/mlir/include/mlir/Dialect/SMT/CMakeLists.txt b/mlir/include/mlir/Dialect/SMT/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt new file mode 100644 index 0000000000000..bd743ed510a9e --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect(SMT smt) +add_mlir_doc(SMT SMT Dialects/SMTOps -gen-op-doc) +# TODO(maX) +#add_mlir_doc(SMT SMT Dialects/SMTTypes -gen-typedef-doc -dialect smt) + +set(LLVM_TARGET_DEFINITIONS SMT.td) + +mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRSMTAttrIncGen) +add_dependencies(mlir-headers MLIRSMTAttrIncGen) + +mlir_tablegen(SMTEnums.h.inc -gen-enum-decls) +mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRSMTEnumsIncGen) +add_dependencies(mlir-headers MLIRSMTEnumsIncGen) diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMT.td b/mlir/include/mlir/Dialect/SMT/IR/SMT.td new file mode 100644 index 0000000000000..dd7bd033c9fa5 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMT.td @@ -0,0 +1,22 @@ +//===- SMT.td - SMT dialect definition ---------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMT_TD +#define MLIR_DIALECT_SMT_SMT_TD + +include "mlir/IR/OpBase.td" + +include "mlir/Dialect/SMT/IR/SMTAttributes.td" +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/Dialect/SMT/IR/SMTTypes.td" +include "mlir/Dialect/SMT/IR/SMTOps.td" +include "mlir/Dialect/SMT/IR/SMTArrayOps.td" +include "mlir/Dialect/SMT/IR/SMTBitVectorOps.td" +include "mlir/Dialect/SMT/IR/SMTIntOps.td" + +#endif // MLIR_DIALECT_SMT_SMT_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td new file mode 100644 index 0000000000000..05b5398b6a7f9 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td @@ -0,0 +1,99 @@ +//===- SMTArrayOps.td - SMT array operations ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTARRAYOPS_TD +#define MLIR_DIALECT_SMT_SMTARRAYOPS_TD + +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/Dialect/SMT/IR/SMTAttributes.td" +include "mlir/Dialect/SMT/IR/SMTTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class SMTArrayOp traits = []> : + SMTOp<"array." # mnemonic, traits>; + +def ArrayStoreOp : SMTArrayOp<"store", [ + Pure, + TypesMatchWith<"summary", "array", "index", + "cast($_self).getDomainType()">, + TypesMatchWith<"summary", "array", "value", + "cast($_self).getRangeType()">, + AllTypesMatch<["array", "result"]>, +]> { + let summary = "stores a value at a given index and returns the new array"; + let description = [{ + This operation returns a new array which is the same as the 'array' operand + except that the value at the given 'index' is changed to the given 'value'. + The semantics are equivalent to the 'store' operator described in the + [SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of + the SMT-LIB standard 2.6. + }]; + + let arguments = (ins ArrayType:$array, AnySMTType:$index, AnySMTType:$value); + let results = (outs ArrayType:$result); + + let assemblyFormat = [{ + $array `[` $index `]` `,` $value attr-dict `:` qualified(type($array)) + }]; +} + +def ArraySelectOp : SMTArrayOp<"select", [ + Pure, + TypesMatchWith<"summary", "array", "index", + "cast($_self).getDomainType()">, + TypesMatchWith<"summary", "array", "result", + "cast($_self).getRangeType()">, +]> { + let summary = "get the value stored in the array at the given index"; + let description = [{ + This operation is retuns the value stored in the given array at the given + index. The semantics are equivalent to the `select` operator defined in the + [SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of + the SMT-LIB standard 2.6. + }]; + + let arguments = (ins ArrayType:$array, AnySMTType:$index); + let results = (outs AnySMTType:$result); + + let assemblyFormat = [{ + $array `[` $index `]` attr-dict `:` qualified(type($array)) + }]; +} + +def ArrayBroadcastOp : SMTArrayOp<"broadcast", [ + Pure, + TypesMatchWith<"summary", "result", "value", + "cast($_self).getRangeType()">, +]> { + let summary = "construct an array with the given value stored at every index"; + let description = [{ + This operation represents a broadcast of the 'value' operand to all indices + of the array. It is equivalent to + ``` + %0 = smt.declare "array" : !smt.array<[!smt.int -> !smt.bool]> + %1 = smt.forall ["idx"] { + ^bb0(%idx: !smt.int): + %2 = smt.array.select %0[%idx] : !smt.array<[!smt.int -> !smt.bool]> + %3 = smt.eq %value, %2 : !smt.bool + smt.yield %3 : !smt.bool + } + smt.assert %1 + // return %0 + ``` + + In SMT-LIB, this is frequently written as + `((as const (Array Int Bool)) value)`. + }]; + + let arguments = (ins AnySMTType:$value); + let results = (outs ArrayType:$result); + + let assemblyFormat = "$value attr-dict `:` qualified(type($result))"; +} + +#endif // MLIR_DIALECT_SMT_SMTARRAYOPS_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h new file mode 100644 index 0000000000000..590364d572699 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h @@ -0,0 +1,29 @@ +//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTATTRIBUTES_H +#define MLIR_DIALECT_SMT_SMTATTRIBUTES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir { +namespace smt { +namespace detail { + +struct BitVectorAttrStorage; + +} // namespace detail +} // namespace smt +} // namespace mlir + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/SMT/IR/SMTAttributes.h.inc" + +#endif // MLIR_DIALECT_SMT_SMTATTRIBUTES_H diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td new file mode 100644 index 0000000000000..4231363fdf05b --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td @@ -0,0 +1,74 @@ +//===- SMTAttributes.td - Attributes for SMT dialect -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines SMT dialect specific attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTATTRIBUTES_TD +#define MLIR_DIALECT_SMT_SMTATTRIBUTES_TD + +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" + +def BitVectorAttr : AttrDef +]> { + let mnemonic = "bv"; + let description = [{ + This attribute represents a constant value of the `(_ BitVec width)` sort as + described in the [SMT bit-vector + theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml). + + The constant is as #bX (binary) or #xX (hexadecimal) in SMT-LIB + where X is the value in the corresponding format without any further + prefixing. Here, the bit-vector constant is given as a regular integer + literal and the associated bit-vector type indicating the bit-width. + + Examples: + ```mlir + #smt.bv<5> : !smt.bv<4> + #smt.bv<92> : !smt.bv<8> + ``` + + The explicit type-suffix is mandatory to uniquely represent the attribute, + i.e., this attribute should always be used in the extended form (using the + `quantified` keyword in the operation assembly format string). + + The bit-width must be greater than zero (i.e., at least one digit has to be + present). + }]; + + let parameters = (ins "llvm::APInt":$value); + + let hasCustomAssemblyFormat = true; + let genVerifyDecl = true; + + // We need to manually define the storage class because the generated one is + // buggy (because the APInt asserts matching bitwidth in the `==` operator and + // the generated storage uses that directly. + // Alternatively: add a type parameter to redundantly store the bitwidth of + // of the attribute type, it it's in the order before the 'value' it will be + // checked before the APInt equality (this is the reason it works for the + // builtin integer attribute), but would be more fragile (and we'd store + // duplicate data). + let genStorageClass = false; + + let builders = [ + AttrBuilder<(ins "llvm::StringRef":$value)>, + AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>, + ]; + + let extraClassDeclaration = [{ + /// Return the bit-vector constant as a SMT-LIB formatted string. + std::string getValueAsString(bool prefix = true) const; + }]; +} + +#endif // MLIR_DIALECT_SMT_SMTATTRIBUTES_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td new file mode 100644 index 0000000000000..b6ca34e142d82 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td @@ -0,0 +1,255 @@ +//===- SMTBitVectorOps.td - SMT bit-vector dialect ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTBITVECTOROPS_TD +#define MLIR_DIALECT_SMT_SMTBITVECTOROPS_TD + +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/Dialect/SMT/IR/SMTAttributes.td" +include "mlir/Dialect/SMT/IR/SMTTypes.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class SMTBVOp traits = []> : + Op; + +def BVConstantOp : SMTBVOp<"constant", [ + Pure, + ConstantLike, + FirstAttrDerivedResultType, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { + let summary = "produce a constant bit-vector"; + let description = [{ + This operation produces an SSA value equal to the bit-vector constant + specified by the 'value' attribute. + Refer to the `BitVectorAttr` documentation for more information about + the semantics of bit-vector constants, their format, and associated sort. + The result type always matches the attribute's type. + + Examples: + ```mlir + %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> + %c5_bv4 = smt.bv.constant #smt.bv<5> : !smt.bv<4> + ``` + }]; + + let arguments = (ins BitVectorAttr:$value); + let results = (outs BitVectorType:$result); + + let assemblyFormat = "qualified($value) attr-dict"; + + let builders = [ + OpBuilder<(ins "const llvm::APInt &":$value), [{ + build($_builder, $_state, + BitVectorAttr::get($_builder.getContext(), value)); + }]>, + OpBuilder<(ins "uint64_t":$value, "unsigned":$width), [{ + build($_builder, $_state, + BitVectorAttr::get($_builder.getContext(), value, width)); + }]>, + ]; + + let hasFolder = true; +} + +class BVArithmeticOrBitwiseOp : + SMTBVOp { + let summary = "equivalent to bv" # mnemonic # " in SMT-LIB"; + let description = "This operation performs " # desc # [{. The semantics are + equivalent to the `bv}] # mnemonic # [{` operator defined in the SMT-LIB 2.6 + standard. More precisely in the [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2) + and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2) + describing closed quantifier-free formulas over the theory of fixed-size + bit-vectors. + }]; + + let results = (outs BitVectorType:$result); +} + +class BinaryBVOp : + BVArithmeticOrBitwiseOp { + let arguments = (ins BitVectorType:$lhs, BitVectorType:$rhs); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($result))"; +} + +class UnaryBVOp : + BVArithmeticOrBitwiseOp { + let arguments = (ins BitVectorType:$input); + let assemblyFormat = "$input attr-dict `:` qualified(type($result))"; +} + +def BVNotOp : UnaryBVOp<"not", "bitwise negation">; +def BVNegOp : UnaryBVOp<"neg", "two's complement unary minus">; + +def BVAndOp : BinaryBVOp<"and", "bitwise AND">; +def BVOrOp : BinaryBVOp<"or", "bitwise OR">; +def BVXOrOp : BinaryBVOp<"xor", "bitwise exclusive OR">; + +def BVAddOp : BinaryBVOp<"add", "addition">; +def BVMulOp : BinaryBVOp<"mul", "multiplication">; +def BVUDivOp : BinaryBVOp<"udiv", "unsigned division (rounded towards zero)">; +def BVSDivOp : BinaryBVOp<"sdiv", "two's complement signed division">; +def BVURemOp : BinaryBVOp<"urem", "unsigned remainder">; +def BVSRemOp : BinaryBVOp<"srem", + "two's complement signed remainder (sign follows dividend)">; +def BVSModOp : BinaryBVOp<"smod", + "two's complement signed remainder (sign follows divisor)">; +def BVShlOp : BinaryBVOp<"shl", "shift left">; +def BVLShrOp : BinaryBVOp<"lshr", "logical shift right">; +def BVAShrOp : BinaryBVOp<"ashr", "arithmetic shift right">; + +def PredicateSLT : I64EnumAttrCase<"slt", 0>; +def PredicateSLE : I64EnumAttrCase<"sle", 1>; +def PredicateSGT : I64EnumAttrCase<"sgt", 2>; +def PredicateSGE : I64EnumAttrCase<"sge", 3>; +def PredicateULT : I64EnumAttrCase<"ult", 4>; +def PredicateULE : I64EnumAttrCase<"ule", 5>; +def PredicateUGT : I64EnumAttrCase<"ugt", 6>; +def PredicateUGE : I64EnumAttrCase<"uge", 7>; +let cppNamespace = "mlir::smt" in +def BVCmpPredicate : I64EnumAttr< + "BVCmpPredicate", + "smt bit-vector comparison predicate", + [PredicateSLT, PredicateSLE, PredicateSGT, PredicateSGE, + PredicateULT, PredicateULE, PredicateUGT, PredicateUGE]>; + +def BVCmpOp : SMTBVOp<"cmp", [Pure, SameTypeOperands]> { + let summary = "compare bit-vectors interpreted as signed or unsigned"; + let description = [{ + This operation compares bit-vector values, interpreting them as signed or + unsigned values depending on the predicate. The semantics are equivalent to + the `bvslt`, `bvsle`, `bvsgt`, `bvsge`, `bvult`, `bvule`, `bvugt`, or + `bvuge` operator defined in the SMT-LIB 2.6 standard depending on the + specified predicate. More precisely in the + [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2) + and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2) + describing closed quantifier-free formulas over the theory of fixed-size + bit-vectors. + }]; + + let arguments = (ins BVCmpPredicate:$pred, + BitVectorType:$lhs, + BitVectorType:$rhs); + let results = (outs BoolType:$result); + + let assemblyFormat = [{ + $pred $lhs `,` $rhs attr-dict `:` qualified(type($lhs)) + }]; +} + +def ConcatOp : SMTBVOp<"concat", [ + Pure, + DeclareOpInterfaceMethods +]> { + let summary = "bit-vector concatenation"; + let description = [{ + This operation concatenates bit-vector values with semantics equivalent to + the `concat` operator defined in the SMT-LIB 2.6 standard. More precisely in + the [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2) + and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2) + describing closed quantifier-free formulas over the theory of fixed-size + bit-vectors. + + Note that the following equivalences hold: + * `smt.bv.concat %a, %b : !smt.bv<4>, !smt.bv<4>` is equivalent to + `(concat a b)` in SMT-LIB + * `(= (concat #xf #x0) #xf0)` + }]; + + let arguments = (ins BitVectorType:$lhs, BitVectorType:$rhs); + let results = (outs BitVectorType:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type(operands))"; +} + +def ExtractOp : SMTBVOp<"extract", [Pure]> { + let summary = "bit-vector extraction"; + let description = [{ + This operation extracts the range of bits starting at the 'lowBit' index + (inclusive) up to the 'lowBit' + result-width index (exclusive). The + semantics are equivalent to the `extract` operator defined in the SMT-LIB + 2.6 standard. More precisely in the + [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2) + and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2) + describing closed quantifier-free formulas over the theory of fixed-size + bit-vectors. + + Note that `smt.bv.extract %bv from 2 : (!smt.bv<32>) -> !smt.bv<16>` is + equivalent to `((_ extract 17 2) bv)`, i.e., the SMT-LIB operator takes the + low and high indices where both are inclusive. The following equivalence + holds: `(= ((_ extract 3 0) #x0f) #xf)` + }]; + + let arguments = (ins I32Attr:$lowBit, BitVectorType:$input); + let results = (outs BitVectorType:$result); + + let assemblyFormat = [{ + $input `from` $lowBit attr-dict `:` functional-type($input, $result) + }]; + + let hasVerifier = true; +} + +def RepeatOp : SMTBVOp<"repeat", [Pure]> { + let summary = "repeated bit-vector concatenation of one value"; + let description = [{ + This operation is a shorthand for repeated concatenation of the same + bit-vector value, i.e., + ```mlir + smt.bv.repeat 5 times %a : !smt.bv<4> + // is the same as + %0 = smt.bv.repeat 4 times %a : !smt.bv<4> + smt.bv.concat %a, %0 : !smt.bv<4>, !smt.bv<16> + // or also + %0 = smt.bv.repeat 4 times %a : !smt.bv<4> + smt.bv.concat %0, %a : !smt.bv<16>, !smt.bv<4> + ``` + + The semantics are equivalent to the `repeat` operator defined in the SMT-LIB + 2.6 standard. More precisely in the + [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2) + and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2) + describing closed quantifier-free formulas over the theory of fixed-size + bit-vectors. + }]; + + let arguments = (ins BitVectorType:$input); + let results = (outs BitVectorType:$result); + + let hasCustomAssemblyFormat = true; + let hasVerifier = true; + + let builders = [ + OpBuilder<(ins "unsigned":$count, "mlir::Value":$input)>, + ]; + + let extraClassDeclaration = [{ + /// Get the number of times the input operand is repeated. + unsigned getCount(); + }]; +} + +def BV2IntOp : SMTOp<"bv2int", [Pure]> { + let summary = "Convert an SMT bit-vector to an SMT integer."; + let description = [{ + Create an integer from the bit-vector argument `input`. If `is_signed` is + present, the bit-vector is treated as two's complement signed. Otherwise, + it is treated as an unsigned integer in the range [0..2^N-1], where N is + the number of bits in `input`. + }]; + let arguments = (ins BitVectorType:$input, UnitAttr:$is_signed); + let results = (outs IntType:$result); + let assemblyFormat = [{$input (`signed` $is_signed^)? attr-dict `:` + qualified(type($input))}]; +} + +#endif // MLIR_DIALECT_SMT_SMTBITVECTOROPS_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h new file mode 100644 index 0000000000000..e808583a9e593 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h @@ -0,0 +1,20 @@ +//===- SMTDialect.h - SMT dialect definition --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTDIALECT_H +#define MLIR_DIALECT_SMT_SMTDIALECT_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Support/LLVM.h" + +// Pull in the dialect definition. +#include "mlir/Dialect/SMT/IR/SMTDialect.h.inc" +#include "mlir/Dialect/SMT/IR/SMTEnums.h.inc" + +#endif // MLIR_DIALECT_SMT_SMTDIALECT_H diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td new file mode 100644 index 0000000000000..4b74187e85b87 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td @@ -0,0 +1,30 @@ +//===- SMTDialect.td - SMT dialect definition --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTDIALECT_TD +#define MLIR_DIALECT_SMT_SMTDIALECT_TD + +include "mlir/IR/DialectBase.td" + +def SMTDialect : Dialect { + let name = "smt"; + let summary = "a dialect that models satisfiability modulo theories"; + let cppNamespace = "mlir::smt"; + + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; + + let hasConstantMaterializer = 1; + + let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + }]; +} + +#endif // MLIR_DIALECT_SMT_SMTDIALECT_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td new file mode 100644 index 0000000000000..6606c9608ef55 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td @@ -0,0 +1,137 @@ +//===- SMTIntOps.td - SMT dialect int theory operations ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTINTOPS_TD +#define MLIR_DIALECT_SMT_SMTINTOPS_TD + +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/Dialect/SMT/IR/SMTAttributes.td" +include "mlir/Dialect/SMT/IR/SMTTypes.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class SMTIntOp traits = []> : + SMTOp<"int." # mnemonic, traits>; + +def IntConstantOp : SMTIntOp<"constant", [ + Pure, + ConstantLike, + DeclareOpInterfaceMethods, +]> { + let summary = "produce a constant (infinite-precision) integer"; + let description = [{ + This operation represents (infinite-precision) integer literals of the `Int` + sort. The set of values for the sort `Int` consists of all numerals and + all terms of the form `-n`where n is a numeral other than 0. For more + information refer to the + [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the + SMT-LIB 2.6 standard. + }]; + + let arguments = (ins APIntAttr:$value); + let results = (outs IntType:$result); + + let hasCustomAssemblyFormat = true; + let hasFolder = true; +} + +class VariadicIntOp : SMTIntOp { + let description = [{ + This operation represents (infinite-precision) }] # summary # [{. + The semantics are equivalent to the corresponding operator described in + the [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the + SMT-LIB 2.6 standard. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs IntType:$result); + let assemblyFormat = "$inputs attr-dict"; + + let builders = [ + OpBuilder<(ins "mlir::ValueRange":$inputs), [{ + build($_builder, $_state, $_builder.getType(), inputs); + }]>, + ]; +} + +class BinaryIntOp : SMTIntOp { + let description = [{ + This operation represents (infinite-precision) }] # summary # [{. + The semantics are equivalent to the corresponding operator described in + the [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the + SMT-LIB 2.6 standard. + }]; + + let arguments = (ins IntType:$lhs, IntType:$rhs); + let results = (outs IntType:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict"; +} + +def IntAbsOp : SMTIntOp<"abs", [Pure]> { + let summary = "the absolute value of an Int"; + let description = [{ + This operation represents the absolute value function for the `Int` sort. + The semantics are equivalent to the `abs` operator as described in the + [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the + SMT-LIB 2.6 standard. + }]; + + let arguments = (ins IntType:$input); + let results = (outs IntType:$result); + let assemblyFormat = "$input attr-dict"; +} + +def IntAddOp : VariadicIntOp<"add"> { let summary = "integer addition"; } +def IntMulOp : VariadicIntOp<"mul"> { let summary = "integer multiplication"; } +def IntSubOp : BinaryIntOp<"sub"> { let summary = "integer subtraction"; } +def IntDivOp : BinaryIntOp<"div"> { let summary = "integer division"; } +def IntModOp : BinaryIntOp<"mod"> { let summary = "integer remainder"; } + +def IntPredicateLT : I64EnumAttrCase<"lt", 0>; +def IntPredicateLE : I64EnumAttrCase<"le", 1>; +def IntPredicateGT : I64EnumAttrCase<"gt", 2>; +def IntPredicateGE : I64EnumAttrCase<"ge", 3>; +let cppNamespace = "mlir::smt" in +def IntPredicate : I64EnumAttr< + "IntPredicate", + "smt comparison predicate for integers", + [IntPredicateLT, IntPredicateLE, IntPredicateGT, IntPredicateGE]>; + +def IntCmpOp : SMTIntOp<"cmp", [Pure]> { + let summary = "integer comparison"; + let description = [{ + This operation represents the comparison of (infinite-precision) integers. + The semantics are equivalent to the `<= (le)`, `< (lt)`, `>= (ge)`, or + `> (gt)` operator depending on the predicate (indicated in parentheses) as + described in the + [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the + SMT-LIB 2.6 standard. + }]; + + let arguments = (ins IntPredicate:$pred, IntType:$lhs, IntType:$rhs); + let results = (outs BoolType:$result); + let assemblyFormat = "$pred $lhs `,` $rhs attr-dict"; +} + +def Int2BVOp : SMTOp<"int2bv", [Pure]> { + let summary = "Convert an integer to an inferred-width bitvector."; + let description = [{ + Designed to lower directly to an operation of the same name in Z3. The Z3 + C API describes the semantics as follows: + Create an n bit bit-vector from the integer argument t1. + The resulting bit-vector has n bits, where the i'th bit (counting from 0 + to n-1) is 1 if (t1 div 2^i) mod 2 is 1. + The node t1 must have integer sort. + }]; + let arguments = (ins IntType:$input); + let results = (outs BitVectorType:$result); + let assemblyFormat = "$input attr-dict `:` qualified(type($result))"; +} + +#endif // MLIR_DIALECT_SMT_SMTINTOPS_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.h b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.h new file mode 100644 index 0000000000000..859566ec6dbdb --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.h @@ -0,0 +1,25 @@ +//===- SMTOps.h - SMT dialect operations ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTOPS_H +#define MLIR_DIALECT_SMT_SMTOPS_H + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/SMT/IR/SMTAttributes.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/SMT/IR/SMT.h.inc" + +#endif // MLIR_DIALECT_SMT_SMTOPS_H diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td new file mode 100644 index 0000000000000..18a1483f1dab1 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td @@ -0,0 +1,477 @@ +//===- SMTOps.td - SMT dialect operations ------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTOPS_TD +#define MLIR_DIALECT_SMT_SMTOPS_TD + +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/Dialect/SMT/IR/SMTAttributes.td" +include "mlir/Dialect/SMT/IR/SMTTypes.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" + +class SMTOp traits = []> : + Op; + +def DeclareFunOp : SMTOp<"declare_fun", [ + DeclareOpInterfaceMethods +]> { + let summary = "declare a symbolic value of a given sort"; + let description = [{ + This operation declares a symbolic value just as the `declare-const` and + `declare-func` statements in SMT-LIB 2.6. The result type determines the SMT + sort of the symbolic value. The returned value can then be used to refer to + the symbolic value instead of using the identifier like in SMT-LIB. + + The optionally provided string will be used as a prefix for the newly + generated identifier (useful for easier readability when exporting to + SMT-LIB). Each `declare` will always provide a unique new symbolic value + even if the identifier strings are the same. + + Note that there does not exist a separate operation equivalent to + SMT-LIBs `define-fun` since + ``` + (define-fun f (a Int) Int (-a)) + ``` + is only syntactic sugar for + ``` + %f = smt.declare_fun : !smt.func<(!smt.int) !smt.int> + %0 = smt.forall { + ^bb0(%arg0: !smt.int): + %1 = smt.apply_func %f(%arg0) : !smt.func<(!smt.int) !smt.int> + %2 = smt.int.neg %arg0 + %3 = smt.eq %1, %2 : !smt.int + smt.yield %3 : !smt.bool + } + smt.assert %0 + ``` + + Note that this operation cannot be marked as Pure since two operations (even + with the same identifier string) could then be CSEd, leading to incorrect + behavior. + }]; + + let arguments = (ins OptionalAttr:$namePrefix); + let results = (outs Res:$result); + + let assemblyFormat = [{ + ($namePrefix^)? attr-dict `:` qualified(type($result)) + }]; + + let builders = [ + OpBuilder<(ins "mlir::Type":$type), [{ + build($_builder, $_state, type, nullptr); + }]> + ]; +} + +def BoolConstantOp : SMTOp<"constant", [ + Pure, + ConstantLike, + DeclareOpInterfaceMethods, +]> { + let summary = "Produce a constant boolean"; + let description = [{ + Produces the constant expressions 'true' and 'false' as described in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB + Standard 2.6. + }]; + + let arguments = (ins BoolAttr:$value); + let results = (outs BoolType:$result); + let assemblyFormat = "$value attr-dict"; + + let hasFolder = true; +} + +def SolverOp : SMTOp<"solver", [ + IsolatedFromAbove, + SingleBlockImplicitTerminator<"smt::YieldOp">, +]> { + let summary = "create a solver instance within a lifespan"; + let description = [{ + This operation defines an SMT context with a solver instance. SMT operations + are only valid when being executed between the start and end of the region + of this operation. Any invocation outside is undefined. However, they do not + have to be direct children of this operation. For example, it is allowed to + have SMT operations in a `func.func` which is only called from within this + region. No SMT value may enter or exit the lifespan of this region (such + that no value created from another SMT context can be used in this scope and + the solver can deallocate all state required to keep track of SMT values at + the end). + + As a result, the region is comparable to an entire SMT-LIB script, but + allows for concrete operations and control-flow. Concrete values may be + passed in and returned to influence the computations after the `smt.solver` + operation. + + Example: + ```mlir + %0:2 = smt.solver (%in) {smt.some_attr} : (i8) -> (i8, i32) { + ^bb0(%arg0: i8): + %c = smt.declare_fun "c" : !smt.bool + smt.assert %c + %1 = smt.check sat { + %c1_i32 = arith.constant 1 : i32 + smt.yield %c1_i32 : i32 + } unknown { + %c0_i32 = arith.constant 0 : i32 + smt.yield %c0_i32 : i32 + } unsat { + %c-1_i32 = arith.constant -1 : i32 + smt.yield %c-1_i32 : i32 + } -> i32 + smt.yield %arg0, %1 : i8, i32 + } + ``` + + TODO: solver configuration attributes + }]; + + let arguments = (ins Variadic:$inputs); + let regions = (region SizedRegion<1>:$bodyRegion); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + `(` $inputs `)` attr-dict `:` functional-type($inputs, $results) $bodyRegion + }]; + + let hasRegionVerifier = true; +} + +def SetLogicOp : SMTOp<"set_logic", [ + HasParent<"smt::SolverOp">, +]> { + let summary = "set the logic for the SMT solver"; + let arguments = (ins StrAttr:$logic); + let assemblyFormat = "$logic attr-dict"; +} + +def AssertOp : SMTOp<"assert", []> { + let summary = "assert that a boolean expression holds"; + let arguments = (ins BoolType:$input); + let assemblyFormat = "$input attr-dict"; +} + +def ResetOp : SMTOp<"reset", []> { + let summary = "reset the solver"; + let assemblyFormat = "attr-dict"; +} + +def PushOp : SMTOp<"push", []> { + let summary = "push a given number of levels onto the assertion stack"; + let arguments = (ins ConfinedAttr:$count); + let assemblyFormat = "$count attr-dict"; +} + +def PopOp : SMTOp<"pop", []> { + let summary = "pop a given number of levels from the assertion stack"; + let arguments = (ins ConfinedAttr:$count); + let assemblyFormat = "$count attr-dict"; +} + +def CheckOp : SMTOp<"check", [ + NoRegionArguments, + SingleBlockImplicitTerminator<"smt::YieldOp">, +]> { + let summary = "check if the current set of assertions is satisfiable"; + let description = [{ + This operation checks if all the assertions in the solver defined by the + nearest ancestor operation of type `smt.solver` are consistent. The outcome + an be 'satisfiable', 'unknown', or 'unsatisfiable' and the corresponding + region will be executed. It is the corresponding construct to the + `check-sat` in SMT-LIB. + + Example: + ```mlir + %0 = smt.check sat { + %c1_i32 = arith.constant 1 : i32 + smt.yield %c1_i32 : i32 + } unknown { + %c0_i32 = arith.constant 0 : i32 + smt.yield %c0_i32 : i32 + } unsat { + %c-1_i32 = arith.constant -1 : i32 + smt.yield %c-1_i32 : i32 + } -> i32 + ``` + }]; + + let regions = (region SizedRegion<1>:$satRegion, + SizedRegion<1>:$unknownRegion, + SizedRegion<1>:$unsatRegion); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + attr-dict `sat` $satRegion `unknown` $unknownRegion `unsat` $unsatRegion + (`->` qualified(type($results))^ )? + }]; + + let hasRegionVerifier = true; +} + +def YieldOp : SMTOp<"yield", [ + Pure, + Terminator, + ReturnLike, + ParentOneOf<["smt::SolverOp", "smt::CheckOp", + "smt::ForallOp", "smt::ExistsOp"]>, +]> { + let summary = "terminator operation for various regions of SMT operations"; + let arguments = (ins Variadic:$values); + let assemblyFormat = "($values^ `:` qualified(type($values)))? attr-dict"; + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; +} + +def ApplyFuncOp : SMTOp<"apply_func", [ + Pure, + TypesMatchWith<"summary", "func", "result", + "cast($_self).getRangeType()">, + RangedTypesMatchWith<"summary", "func", "args", + "cast($_self).getDomainTypes()"> +]> { + let summary = "apply a function"; + let description = [{ + This operation performs a function application as described in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + It is part of the language itself rather than a theory or logic. + }]; + + let arguments = (ins SMTFuncType:$func, + Variadic:$args); + let results = (outs AnyNonFuncSMTType:$result); + + let assemblyFormat = [{ + $func `(` $args `)` attr-dict `:` qualified(type($func)) + }]; +} + +def EqOp : SMTOp<"eq", [Pure, SameTypeOperands]> { + let summary = "returns true iff all operands are identical"; + let description = [{ + This operation compares the operands and returns true iff all operands are + identical. The semantics are equivalent to the `=` operator defined in the + SMT-LIB Standard 2.6 in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2). + + Any SMT sort/type is allowed for the operands and it supports a variadic + number of operands, but requires at least two. This is because the `=` + operator is annotated with `:chainable` which means that `= a b c d` is + equivalent to `and (= a b) (= b c) (= c d)` where `and` is annotated + `:left-assoc`, i.e., it can be further rewritten to + `and (and (= a b) (= b c)) (= c d)`. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs BoolType:$result); + + let builders = [ + OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs), [{ + build($_builder, $_state, ValueRange{lhs, rhs}); + }]> + ]; + + let hasCustomAssemblyFormat = true; + let hasVerifier = true; +} + +def DistinctOp : SMTOp<"distinct", [Pure, SameTypeOperands]> { + let summary = "returns true iff all operands are not identical to any other"; + let description = [{ + This operation compares the operands and returns true iff all operands are + not identical to any of the other operands. The semantics are equivalent to + the `distinct` operator defined in the SMT-LIB Standard 2.6 in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2). + + Any SMT sort/type is allowed for the operands and it supports a variadic + number of operands, but requires at least two. This is because the + `distinct` operator is annotated with `:pairwise` which means that + `distinct a b c d` is equivalent to + ``` + and (distinct a b) (distinct a c) (distinct a d) + (distinct b c) (distinct b d) + (distinct c d) + ``` + where `and` is annotated `:left-assoc`, i.e., it can be further rewritten to + ``` + (and (and (and (and (and (distinct a b) + (distinct a c)) + (distinct a d)) + (distinct b c)) + (distinct b d)) + (distinct c d) + ``` + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs BoolType:$result); + + let builders = [ + OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs), [{ + build($_builder, $_state, ValueRange{lhs, rhs}); + }]> + ]; + + let hasCustomAssemblyFormat = true; + let hasVerifier = true; +} + +def IteOp : SMTOp<"ite", [ + Pure, + AllTypesMatch<["thenValue", "elseValue", "result"]> +]> { + let summary = "an if-then-else function"; + let description = [{ + This operation returns its second operand or its third operand depending on + whether its first operand is true or not. The semantics are equivalent to + the `ite` operator defined in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB + 2.6 standard. + }]; + + let arguments = (ins BoolType:$cond, + AnySMTType:$thenValue, + AnySMTType:$elseValue); + let results = (outs AnySMTType:$result); + + let assemblyFormat = [{ + $cond `,` $thenValue `,` $elseValue attr-dict `:` qualified(type($result)) + }]; +} + +def NotOp : SMTOp<"not", [Pure]> { + let summary = "a boolean negation"; + let description = [{ + This operation performs a boolean negation. The semantics are equivalent to + the 'not' operator in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB + Standard 2.6. + }]; + + let arguments = (ins BoolType:$input); + let results = (outs BoolType:$result); + let assemblyFormat = "$input attr-dict"; +} + +class VariadicBoolOp : SMTOp { + let summary = desc; + let description = "This operation performs " # desc # [{. + The semantics are equivalent to the '}] # mnemonic # [{' operator in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2). + of the SMT-LIB Standard 2.6. + + It supports a variadic number of operands, but requires at least two. + This is because the operator is annotated with the `:left-assoc` attribute + which means that `op a b c` is equivalent to `(op (op a b) c)`. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs BoolType:$result); + let assemblyFormat = "$inputs attr-dict"; + + let builders = [ + OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs), [{ + build($_builder, $_state, ValueRange{lhs, rhs}); + }]> + ]; +} + +def AndOp : VariadicBoolOp<"and", "a boolean conjunction">; +def OrOp : VariadicBoolOp<"or", "a boolean disjunction">; +def XOrOp : VariadicBoolOp<"xor", "a boolean exclusive OR">; + +def ImpliesOp : SMTOp<"implies", [Pure]> { + let summary = "boolean implication"; + let description = [{ + This operation performs a boolean implication. The semantics are equivalent + to the '=>' operator in the + [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB + Standard 2.6. + }]; + + let arguments = (ins BoolType:$lhs, BoolType:$rhs); + let results = (outs BoolType:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict"; +} + +class QuantifierOp : SMTOp, +]> { + let description = [{ + This operation represents the }] # summary # [{ as described in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + It is part of the language itself rather than a theory or logic. + + The operation specifies the name prefixes (as an optional attribute) and + types (as the types of the block arguments of the regions) of bound + variables that may be used in the 'body' of the operation. If a 'patterns' + region is specified, the block arguments must match the ones of the 'body' + region and (other than there) must be used at least once in the 'patterns' + region. It may also not contain any operations that bind variables, such as + quantifiers. While the 'body' region must always yield exactly one + `!smt.bool`-typed value, the 'patterns' region can yield an arbitrary number + (but at least one) of SMT values. + + The bound variables can be any SMT type except of functions, since SMT only + supports first-order logic. + + The 'no_patterns' attribute is only allowed when no 'patterns' region is + specified and forbids the solver to generate and use patterns for this + quantifier. + + The 'weight' attribute indicates the importance of this quantifier being + instantiated compared to other quantifiers that may be present. The default + value is zero. + + Both the 'no_patterns' and 'weight' attributes are annotations to the + quantifiers body term. Annotations and attributes are described in the + standard in sections 3.4, and 3.6 (specifically 3.6.5). SMT-LIB allows + adding custom attributes to provide solvers with additional metadata, e.g., + hints such as above mentioned attributes. They are not part of the standard + themselves, but supported by common SMT solvers (e.g., Z3). + }]; + + let arguments = (ins DefaultValuedAttr:$weight, + UnitAttr:$noPattern, + OptionalAttr:$boundVarNames); + let regions = (region SizedRegion<1>:$body, + VariadicRegion>:$patterns); + let results = (outs BoolType:$result); + + let builders = [ + OpBuilder<(ins + "TypeRange":$boundVarTypes, + "function_ref":$bodyBuilder, + CArg<"std::optional>", "std::nullopt">:$boundVarNames, + CArg<"function_ref", + "{}">:$patternBuilder, + CArg<"uint32_t", "0">:$weight, + CArg<"bool", "false">:$noPattern)> + ]; + let skipDefaultBuilders = true; + + let assemblyFormat = [{ + ($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)? + attr-dict-with-keyword $body (`patterns` $patterns^)? + }]; + + let hasVerifier = true; + let hasRegionVerifier = true; +} + +def ForallOp : QuantifierOp<"forall"> { let summary = "forall quantifier"; } +def ExistsOp : QuantifierOp<"exists"> { let summary = "exists quantifier"; } + +#endif // MLIR_DIALECT_SMT_SMTOPS_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h new file mode 100644 index 0000000000000..4db28f7a07a41 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h @@ -0,0 +1,30 @@ +//===- SMTTypes.h - SMT dialect types ---------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTTYPES_H +#define MLIR_DIALECT_SMT_SMTTYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/SMT/IR/SMTTypes.h.inc" + +namespace mlir { +namespace smt { + +/// Returns whether the given type is an SMT value type. +bool isAnySMTValueType(mlir::Type type); + +/// Returns whether the given type is an SMT value type (excluding functions). +bool isAnyNonFuncSMTValueType(mlir::Type type); + +} // namespace smt +} // namespace mlir + +#endif // MLIR_DIALECT_SMT_SMTTYPES_H diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td new file mode 100644 index 0000000000000..3032900b52178 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td @@ -0,0 +1,145 @@ +//===- SMTTypes.td - SMT dialect types ---------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTTYPES_TD +#define MLIR_DIALECT_SMT_SMTTYPES_TD + +include "mlir/Dialect/SMT/IR/SMTDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class SMTTypeDef : TypeDef { } + +def BoolType : SMTTypeDef<"Bool"> { + let mnemonic = "bool"; + let assemblyFormat = ""; +} + +def IntType : SMTTypeDef<"Int"> { + let mnemonic = "int"; + let description = [{ + This type represents the `Int` sort as described in the + [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the + SMT-LIB 2.6 standard. + }]; + let assemblyFormat = ""; +} + +def BitVectorType : SMTTypeDef<"BitVector"> { + let mnemonic = "bv"; + let description = [{ + This type represents the `(_ BitVec width)` sort as described in the + [SMT bit-vector + theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml). + + The bit-width must be strictly greater than zero. + }]; + + let parameters = (ins "int64_t":$width); + let assemblyFormat = "`<` $width `>`"; + + let genVerifyDecl = true; +} + +def ArrayType : SMTTypeDef<"Array"> { + let mnemonic = "array"; + let description = [{ + This type represents the `(Array X Y)` sort, where X and Y are any + sort/type, as described in the + [SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of + the SMT-LIB standard 2.6. + }]; + + let parameters = (ins "mlir::Type":$domainType, "mlir::Type":$rangeType); + let assemblyFormat = "`<` `[` $domainType `->` $rangeType `]` `>`"; + + let genVerifyDecl = true; +} + +def SMTFuncType : SMTTypeDef<"SMTFunc"> { + let mnemonic = "func"; + let description = [{ + This type represents the SMT function sort as described in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + It is part of the language itself rather than a theory or logic. + + A function in SMT can have an arbitrary domain size, but always has exactly + one range sort. + + Since SMT only supports first-order logic, it is not possible to nest + function types. + + Example: `!smt.func<(!smt.bool, !smt.int) !smt.bool>` is equivalent to + `((Bool Int) Bool)` in SMT-LIB. + }]; + + let parameters = (ins + ArrayRefParameter<"mlir::Type", "domain types">:$domainTypes, + "mlir::Type":$rangeType + ); + + // Note: We are not printing the parentheses when no domain type is present + // because the default MLIR parser thinks it is a builtin function type + // otherwise. + let assemblyFormat = "`<` `(` $domainTypes `)` ` ` $rangeType `>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$domainTypes, + "mlir::Type":$rangeType), [{ + return $_get(rangeType.getContext(), domainTypes, rangeType); + }]>, + TypeBuilderWithInferredContext<(ins "mlir::Type":$rangeType), [{ + return $_get(rangeType.getContext(), + llvm::ArrayRef{}, rangeType); + }]> + ]; + + let genVerifyDecl = true; +} + +def SortType : SMTTypeDef<"Sort"> { + let mnemonic = "sort"; + let description = [{ + This type represents uninterpreted sorts. The usage of a type like + `!smt.sort<"sort_name"[!smt.bool, !smt.sort<"other_sort">]>` implies a + `declare-sort sort_name 2` and a `declare-sort other_sort 0` in SMT-LIB. + This type represents concrete use-sites of such declared sorts, in this + particular case it would be equivalent to `(sort_name Bool other_sort)` in + SMT-LIB. More details about the semantics can be found in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + }]; + + let parameters = (ins + "mlir::StringAttr":$identifier, + OptionalArrayRefParameter<"mlir::Type", "sort parameters">:$sortParams + ); + + let assemblyFormat = "`<` $identifier (`[` $sortParams^ `]`)? `>`"; + + let builders = [ + TypeBuilder<(ins "llvm::StringRef":$identifier, + "llvm::ArrayRef":$sortParams), [{ + return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier), + sortParams); + }]>, + TypeBuilder<(ins "llvm::StringRef":$identifier), [{ + return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier), + llvm::ArrayRef{}); + }]>, + ]; + + let genVerifyDecl = true; +} + +def AnySMTType : Type, + "any SMT value type">; +def AnyNonFuncSMTType : Type, + "any non-function SMT value type">; +def AnyNonSMTType : Type, "any non-smt type">; + +#endif // MLIR_DIALECT_SMT_SMTTYPES_TD diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h b/mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h new file mode 100644 index 0000000000000..38fad21019158 --- /dev/null +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h @@ -0,0 +1,201 @@ +//===- SMTVisitors.h - SMT Dialect Visitors ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines visitors that make it easier to work with the SMT IR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SMT_SMTVISITORS_H +#define MLIR_DIALECT_SMT_SMTVISITORS_H + +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace smt { + +/// This helps visit SMT nodes. +template +class SMTOpVisitor { +public: + ResultType dispatchSMTOpVisitor(Operation *op, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(op) + .template Case< + // Constants + BoolConstantOp, IntConstantOp, BVConstantOp, + // Bit-vector arithmetic + BVNegOp, BVAddOp, BVMulOp, BVURemOp, BVSRemOp, BVSModOp, BVShlOp, + BVLShrOp, BVAShrOp, BVUDivOp, BVSDivOp, + // Bit-vector bitwise + BVNotOp, BVAndOp, BVOrOp, BVXOrOp, + // Other bit-vector ops + ConcatOp, ExtractOp, RepeatOp, BVCmpOp, BV2IntOp, + // Int arithmetic + IntAddOp, IntMulOp, IntSubOp, IntDivOp, IntModOp, IntCmpOp, + Int2BVOp, + // Core Ops + EqOp, DistinctOp, IteOp, + // Variable/symbol declaration + DeclareFunOp, ApplyFuncOp, + // solver interaction + SolverOp, AssertOp, ResetOp, PushOp, PopOp, CheckOp, SetLogicOp, + // Boolean logic + NotOp, AndOp, OrOp, XOrOp, ImpliesOp, + // Arrays + ArrayStoreOp, ArraySelectOp, ArrayBroadcastOp, + // Quantifiers + ForallOp, ExistsOp, YieldOp>([&](auto expr) -> ResultType { + return thisCast->visitSMTOp(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + return thisCast->visitInvalidSMTOp(op, args...); + }); + } + + /// This callback is invoked on any non-expression operations. + ResultType visitInvalidSMTOp(Operation *op, ExtraArgs... args) { + op->emitOpError("unknown SMT node"); + abort(); + } + + /// This callback is invoked on any SMT operations that are not + /// handled by the concrete visitor. + ResultType visitUnhandledSMTOp(Operation *op, ExtraArgs... args) { + return ResultType(); + } + +#define HANDLE(OPTYPE, OPKIND) \ + ResultType visitSMTOp(OPTYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##OPKIND##SMTOp(op, \ + args...); \ + } + + // Constants + HANDLE(BoolConstantOp, Unhandled); + HANDLE(IntConstantOp, Unhandled); + HANDLE(BVConstantOp, Unhandled); + + // Bit-vector arithmetic + HANDLE(BVNegOp, Unhandled); + HANDLE(BVAddOp, Unhandled); + HANDLE(BVMulOp, Unhandled); + HANDLE(BVURemOp, Unhandled); + HANDLE(BVSRemOp, Unhandled); + HANDLE(BVSModOp, Unhandled); + HANDLE(BVShlOp, Unhandled); + HANDLE(BVLShrOp, Unhandled); + HANDLE(BVAShrOp, Unhandled); + HANDLE(BVUDivOp, Unhandled); + HANDLE(BVSDivOp, Unhandled); + + // Bit-vector bitwise operations + HANDLE(BVNotOp, Unhandled); + HANDLE(BVAndOp, Unhandled); + HANDLE(BVOrOp, Unhandled); + HANDLE(BVXOrOp, Unhandled); + + // Other bit-vector operations + HANDLE(ConcatOp, Unhandled); + HANDLE(ExtractOp, Unhandled); + HANDLE(RepeatOp, Unhandled); + HANDLE(BVCmpOp, Unhandled); + HANDLE(BV2IntOp, Unhandled); + + // Int arithmetic + HANDLE(IntAddOp, Unhandled); + HANDLE(IntMulOp, Unhandled); + HANDLE(IntSubOp, Unhandled); + HANDLE(IntDivOp, Unhandled); + HANDLE(IntModOp, Unhandled); + + HANDLE(IntCmpOp, Unhandled); + HANDLE(Int2BVOp, Unhandled); + + HANDLE(EqOp, Unhandled); + HANDLE(DistinctOp, Unhandled); + HANDLE(IteOp, Unhandled); + + HANDLE(DeclareFunOp, Unhandled); + HANDLE(ApplyFuncOp, Unhandled); + + HANDLE(SolverOp, Unhandled); + HANDLE(AssertOp, Unhandled); + HANDLE(ResetOp, Unhandled); + HANDLE(PushOp, Unhandled); + HANDLE(PopOp, Unhandled); + HANDLE(CheckOp, Unhandled); + HANDLE(SetLogicOp, Unhandled); + + // Boolean logic operations + HANDLE(NotOp, Unhandled); + HANDLE(AndOp, Unhandled); + HANDLE(OrOp, Unhandled); + HANDLE(XOrOp, Unhandled); + HANDLE(ImpliesOp, Unhandled); + + // Array operations + HANDLE(ArrayStoreOp, Unhandled); + HANDLE(ArraySelectOp, Unhandled); + HANDLE(ArrayBroadcastOp, Unhandled); + + // Quantifier operations + HANDLE(ForallOp, Unhandled); + HANDLE(ExistsOp, Unhandled); + HANDLE(YieldOp, Unhandled); + +#undef HANDLE +}; + +/// This helps visit SMT types. +template +class SMTTypeVisitor { +public: + ResultType dispatchSMTTypeVisitor(Type type, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(type) + .template Case([&](auto expr) -> ResultType { + return thisCast->visitSMTType(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + return thisCast->visitInvalidSMTType(type, args...); + }); + } + + /// This callback is invoked on any non-expression types. + ResultType visitInvalidSMTType(Type type, ExtraArgs... args) { abort(); } + + /// This callback is invoked on any SMT type that are not + /// handled by the concrete visitor. + ResultType visitUnhandledSMTType(Type type, ExtraArgs... args) { + return ResultType(); + } + +#define HANDLE(TYPE, KIND) \ + ResultType visitSMTType(TYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##KIND##SMTType(op, \ + args...); \ + } + + HANDLE(BoolType, Unhandled); + HANDLE(IntegerType, Unhandled); + HANDLE(BitVectorType, Unhandled); + HANDLE(ArrayType, Unhandled); + HANDLE(SMTFuncType, Unhandled); + HANDLE(SortType, Unhandled); + +#undef HANDLE +}; + +} // namespace smt +} // namespace mlir + +#endif // MLIR_DIALECT_SMT_SMTVISITORS_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 33bc89279c08c..e83be7b40eded 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -73,6 +73,7 @@ #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" @@ -143,6 +144,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { ROCDL::ROCDLDialect, scf::SCFDialect, shape::ShapeDialect, + smt::SMTDialect, sparse_tensor::SparseTensorDialect, spirv::SPIRVDialect, tensor::TensorDialect, diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index dd8b292a87344..bb8dff47ab480 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -95,6 +95,7 @@ inline void registerAllPasses() { arm_sve::registerArmSVEPasses(); emitc::registerEmitCPasses(); xegpu::registerXeGPUPasses(); + registerConvertArithToSMTPass(); // Dialect pipelines bufferization::registerBufferizationPipelines(); diff --git a/mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp b/mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp new file mode 100644 index 0000000000000..6b8714a5a1c44 --- /dev/null +++ b/mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp @@ -0,0 +1,351 @@ +//===- ArithToSMT.cpp +//------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToSMT/ArithToSMT.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +#include "llvm/Support/Debug.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARITHTOSMT +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Lower a arith::CmpIOp operation to a smt::BVCmpOp, smt::EqOp or +/// smt::DistinctOp +/// +struct CmpIOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getPredicate() == arith::CmpIPredicate::eq) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + if (adaptor.getPredicate() == arith::CmpIPredicate::ne) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + smt::BVCmpPredicate pred; + switch (adaptor.getPredicate()) { + case arith::CmpIPredicate::sge: + pred = smt::BVCmpPredicate::sge; + break; + case arith::CmpIPredicate::sgt: + pred = smt::BVCmpPredicate::sgt; + break; + case arith::CmpIPredicate::sle: + pred = smt::BVCmpPredicate::sle; + break; + case arith::CmpIPredicate::slt: + pred = smt::BVCmpPredicate::slt; + break; + case arith::CmpIPredicate::uge: + pred = smt::BVCmpPredicate::uge; + break; + case arith::CmpIPredicate::ugt: + pred = smt::BVCmpPredicate::ugt; + break; + case arith::CmpIPredicate::ule: + pred = smt::BVCmpPredicate::ule; + break; + case arith::CmpIPredicate::ult: + pred = smt::BVCmpPredicate::ult; + break; + default: + llvm_unreachable("all cases handled above"); + } + + rewriter.replaceOpWithNewOp(op, pred, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +/// Lower a arith::SubOp operation to an smt::BVNegOp + smt::BVAddOp +struct SubOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value negRhs = rewriter.create(op.getLoc(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), negRhs); + return success(); + } +}; + +/// Lower the SourceOp to the TargetOp one-to-one. +template +struct OneToOneOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::typeConverter->convertType( + op.getResult().getType()), + adaptor.getOperands()); + return success(); + } +}; + +struct CeilDivSIOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto numPlusDenom = rewriter.createOrFold( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + auto bitWidth = + llvm::cast(getElementTypeOrSelf(adaptor.getLhs())) + .getWidth(); + auto one = rewriter.create(op.getLoc(), 1, bitWidth); + auto numPlusDenomMinusOne = + rewriter.createOrFold(op.getLoc(), numPlusDenom, one); + rewriter.replaceOpWithNewOp(op, numPlusDenomMinusOne, + adaptor.getRhs()); + return success(); + } +}; + +/// Lower the SourceOp to the TargetOp special-casing if the second operand is +/// zero to return a new symbolic value. +template +struct DivisionOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto type = dyn_cast(adaptor.getRhs().getType()); + if (!type) + return failure(); + + auto resultType = OpConversionPattern::typeConverter->convertType( + op.getResult().getType()); + Value zero = + rewriter.create(loc, APInt(type.getWidth(), 0)); + Value isZero = rewriter.create(loc, adaptor.getRhs(), zero); + Value symbolicVal = rewriter.create(loc, resultType); + Value division = + rewriter.create(loc, resultType, adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, isZero, symbolicVal, division); + return success(); + } +}; + +/// Converts an operation with a variadic number of operands to a chain of +/// binary operations assuming left-associativity of the operation. +template +struct VariadicToBinaryOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + ValueRange operands = adaptor.getOperands(); + if (operands.size() < 2) + return failure(); + + Value runner = operands[0]; + for (Value operand : operands.drop_front()) + runner = rewriter.create(op.getLoc(), runner, operand); + + rewriter.replaceOp(op, runner); + return success(); + } +}; + +/// Lower a arith::ConstantOp operation to smt::BVConstantOp +struct ArithConstantIntOpConversion + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto v = llvm::cast(adaptor.getValue()); + if (v.getValue().getBitWidth() < 1) + return rewriter.notifyMatchFailure(op.getLoc(), + "0-bit constants not supported"); + // TODO(max): signed/unsigned/signless semenatics + rewriter.replaceOpWithNewOp(op, v.getValue()); + return success(); + } +}; + +} // namespace + +void populateArithToSMTTypeConverter(TypeConverter &converter) { + // The semantics of the builtin integer at the MLIR core level is currently + // not very well defined. It is used for two-valued, four-valued, and possible + // other multi-valued logic. Here, we interpret it as two-valued for now. + // From a formal perspective, MLIR would ideally define its own types for + // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream + // the integer type also carries poison information (which we don't have in + // MLIR?). + converter.addConversion([](IntegerType type) -> std::optional { + if (type.getWidth() <= 0) + return std::nullopt; + return smt::BitVectorType::get(type.getContext(), type.getWidth()); + }); + + // Default target materialization to convert from illegal types to legal + // types, e.g., at the boundary of an inlined child block. + converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Value { + return builder + .create(loc, resultType, inputs) + ->getResult(0); + }); + + // Convert a 'smt.bool'-typed value to a 'smt.bv'-typed value + converter.addTargetMaterialization( + [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + + if (!isa(inputs[0].getType())) + return Value(); + + unsigned width = resultType.getWidth(); + Value constZero = builder.create(loc, 0, width); + Value constOne = builder.create(loc, 1, width); + return builder.create(loc, inputs[0], constOne, constZero); + }); + + // Convert an unrealized conversion cast from 'smt.bool' to i1 + // into a direct conversion from 'smt.bool' to 'smt.bv<1>'. + converter.addTargetMaterialization( + [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1 || resultType.getWidth() != 1) + return Value(); + + auto intType = dyn_cast(inputs[0].getType()); + if (!intType || intType.getWidth() != 1) + return Value(); + + auto castOp = + inputs[0].getDefiningOp(); + if (!castOp || castOp.getInputs().size() != 1) + return Value(); + + if (!isa(castOp.getInputs()[0].getType())) + return Value(); + + Value constZero = builder.create(loc, 0, 1); + Value constOne = builder.create(loc, 1, 1); + return builder.create(loc, castOp.getInputs()[0], constOne, + constZero); + }); + + // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value + converter.addTargetMaterialization( + [&](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + + auto bvType = dyn_cast(inputs[0].getType()); + if (!bvType || bvType.getWidth() != 1) + return Value(); + + Value constOne = builder.create(loc, 1, 1); + return builder.create(loc, inputs[0], constOne); + }); + + // Default source materialization to convert from illegal types to legal + // types, e.g., at the boundary of an inlined child block. + converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Value { + return builder + .create(loc, resultType, inputs) + ->getResult(0); + }); +} + +namespace { +struct ConvertArithToSMT + : public impl::ConvertArithToSMTBase { + using Base::Base; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + + TypeConverter converter; + populateArithToSMTTypeConverter(converter); + patterns.clear(); + arith::populateArithToSMTConversionPatterns(converter, patterns); + + if (failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +namespace mlir::arith { +void populateArithToSMTConversionPatterns(TypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add, + OneToOneOpConversion, + OneToOneOpConversion, + DivisionOpConversion, + DivisionOpConversion, + DivisionOpConversion, + DivisionOpConversion, + VariadicToBinaryOpConversion, + VariadicToBinaryOpConversion, + VariadicToBinaryOpConversion, + VariadicToBinaryOpConversion, + VariadicToBinaryOpConversion>( + converter, patterns.getContext()); +} +} // namespace mlir::arith \ No newline at end of file diff --git a/mlir/lib/Conversion/ArithToSMT/CMakeLists.txt b/mlir/lib/Conversion/ArithToSMT/CMakeLists.txt new file mode 100644 index 0000000000000..ef9df95568cb4 --- /dev/null +++ b/mlir/lib/Conversion/ArithToSMT/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_conversion_library(MLIRCombToSMT + ArithToSMT.cpp + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRSMT + MLIRTransforms +) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index b6c21440c571c..78d0ffd382cce 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(ArithToAMDGPU) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) +add_subdirectory(ArithToSMT) add_subdirectory(ArithToSPIRV) add_subdirectory(ArmNeon2dToIntr) add_subdirectory(ArmSMEToSCF) diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 80b0ef068d96d..a473f2ff317c9 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -33,6 +33,7 @@ add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) +add_subdirectory(SMT) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Tensor) diff --git a/mlir/lib/Dialect/SMT/CMakeLists.txt b/mlir/lib/Dialect/SMT/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/lib/Dialect/SMT/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/SMT/IR/CMakeLists.txt b/mlir/lib/Dialect/SMT/IR/CMakeLists.txt new file mode 100644 index 0000000000000..e287613da9fd0 --- /dev/null +++ b/mlir/lib/Dialect/SMT/IR/CMakeLists.txt @@ -0,0 +1,27 @@ +add_mlir_dialect_library(MLIRSMT + SMTAttributes.cpp + SMTDialect.cpp + SMTOps.cpp + SMTTypes.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SMT + + DEPENDS + MLIRSMTAttrIncGen + MLIRSMTEnumsIncGen + MLIRSMTIncGen + + LINK_COMPONENTS + Support + + LINK_LIBS PUBLIC + MLIRIR + MLIRInferTypeOpInterface + MLIRSideEffectInterfaces + MLIRControlFlowInterfaces +) + +add_dependencies(mlir-headers + MLIRSMTIncGen +) diff --git a/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp b/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp new file mode 100644 index 0000000000000..c28f3558a02d2 --- /dev/null +++ b/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp @@ -0,0 +1,201 @@ +//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SMT/IR/SMTAttributes.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Format.h" + +using namespace mlir; +using namespace mlir::smt; + +//===----------------------------------------------------------------------===// +// BitVectorAttr +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace smt { +namespace detail { +struct BitVectorAttrStorage : public mlir::AttributeStorage { + using KeyTy = APInt; + BitVectorAttrStorage(APInt value) : value(std::move(value)) {} + + KeyTy getAsKey() const { return value; } + + // NOTE: the implementation of this operator is the reason we need to define + // the storage manually. The auto-generated version would just do the direct + // equality check of the APInt, but that asserts the bitwidth of both to be + // the same, leading to a crash. This implementation, therefore, checks for + // matching bit-width beforehand. + bool operator==(const KeyTy &key) const { + return (value.getBitWidth() == key.getBitWidth() && value == key); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static BitVectorAttrStorage * + construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) { + return new (allocator.allocate()) + BitVectorAttrStorage(std::move(key)); + } + + APInt value; +}; +} // namespace detail +} // namespace smt +} // namespace mlir + +APInt BitVectorAttr::getValue() const { return getImpl()->value; } + +LogicalResult BitVectorAttr::verify( + function_ref emitError, + APInt value) { // NOLINT(performance-unnecessary-value-param) + if (value.getBitWidth() < 1) + return emitError() << "bit-width must be at least 1, but got " + << value.getBitWidth(); + return success(); +} + +std::string BitVectorAttr::getValueAsString(bool prefix) const { + unsigned width = getValue().getBitWidth(); + SmallVector toPrint; + StringRef pref = prefix ? "#" : ""; + if (width % 4 == 0) { + getValue().toString(toPrint, 16, false, false, false); + // APInt's 'toString' omits leading zeros. However, those are critical here + // because they determine the bit-width of the bit-vector. + SmallVector leadingZeros(width / 4 - toPrint.size(), '0'); + return (pref + "x" + Twine(leadingZeros) + toPrint).str(); + } + + getValue().toString(toPrint, 2, false, false, false); + // APInt's 'toString' omits leading zeros + SmallVector leadingZeros(width - toPrint.size(), '0'); + return (pref + "b" + Twine(leadingZeros) + toPrint).str(); +} + +/// Parse an SMT-LIB formatted bit-vector string. +static FailureOr +parseBitVectorString(function_ref emitError, + StringRef value) { + if (value[0] != '#') + return emitError() << "expected '#'"; + + if (value.size() < 3) + return emitError() << "expected at least one digit"; + + if (value[1] == 'b') + return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), + 2); + + if (value[1] == 'x') + return APInt((value.size() - 2) * 4, + std::string(value.begin() + 2, value.end()), 16); + + return emitError() << "expected either 'b' or 'x'"; +} + +BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { + auto maybeValue = parseBitVectorString(nullptr, value); + + assert(succeeded(maybeValue) && "string must have SMT-LIB format"); + return Base::get(context, *maybeValue); +} + +BitVectorAttr +BitVectorAttr::getChecked(function_ref emitError, + MLIRContext *context, StringRef value) { + auto maybeValue = parseBitVectorString(emitError, value); + if (failed(maybeValue)) + return {}; + + return Base::getChecked(emitError, context, *maybeValue); +} + +BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, + unsigned width) { + return Base::get(context, APInt(width, value)); +} + +BitVectorAttr +BitVectorAttr::getChecked(function_ref emitError, + MLIRContext *context, uint64_t value, + unsigned width) { + if (width < 64 && value >= (UINT64_C(1) << width)) { + emitError() << "value does not fit in a bit-vector of desired width"; + return {}; + } + return Base::getChecked(emitError, context, APInt(width, value)); +} + +Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { + llvm::SMLoc loc = odsParser.getCurrentLocation(); + + APInt val; + if (odsParser.parseLess() || odsParser.parseInteger(val) || + odsParser.parseGreater()) + return {}; + + // Requires the use of `quantified()` in operation assembly formats. + if (!odsType || !llvm::isa(odsType)) { + odsParser.emitError(loc) << "explicit bit-vector type required"; + return {}; + } + + unsigned width = llvm::cast(odsType).getWidth(); + + if (width > val.getBitWidth()) { + // sext is always safe here, even for unsigned values, because the + // parseOptionalInteger method will return something with a zero in the + // top bits if it is a positive number. + val = val.sext(width); + } else if (width < val.getBitWidth()) { + // The parser can return an unnecessarily wide result. + // This isn't a problem, but truncating off bits is bad. + unsigned neededBits = + val.isNegative() ? val.getSignificantBits() : val.getActiveBits(); + if (width < neededBits) { + odsParser.emitError(loc) + << "integer value out of range for given bit-vector type " << odsType; + return {}; + } + val = val.trunc(width); + } + + return BitVectorAttr::get(odsParser.getContext(), val); +} + +void BitVectorAttr::print(AsmPrinter &odsPrinter) const { + // This printer only works for the extended format where the MLIR + // infrastructure prints the type for us. This means, the attribute should + // never be used without `quantified` in an assembly format. + odsPrinter << "<" << getValue() << ">"; +} + +Type BitVectorAttr::getType() const { + return BitVectorType::get(getContext(), getValue().getBitWidth()); +} + +//===----------------------------------------------------------------------===// +// ODS Boilerplate +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc" + +void SMTDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp new file mode 100644 index 0000000000000..66eed861b2bb7 --- /dev/null +++ b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp @@ -0,0 +1,47 @@ +//===- SMTDialect.cpp - SMT dialect implementation ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SMT/IR/SMTAttributes.h" +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" + +using namespace mlir; +using namespace smt; + +void SMTDialect::initialize() { + registerAttributes(); + registerTypes(); + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/SMT/IR/SMT.cpp.inc" + >(); +} + +Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + // BitVectorType constants can materialize into smt.bv.constant + if (auto bvType = dyn_cast(type)) { + if (auto attrValue = dyn_cast(value)) { + assert(bvType == attrValue.getType() && + "attribute and desired result types have to match"); + return builder.create(loc, attrValue); + } + } + + // BoolType constants can materialize into smt.constant + if (auto boolType = dyn_cast(type)) { + if (auto attrValue = dyn_cast(value)) + return builder.create(loc, attrValue); + } + + return nullptr; +} + +#include "mlir/Dialect/SMT/IR/SMTDialect.cpp.inc" +#include "mlir/Dialect/SMT/IR/SMTEnums.cpp.inc" diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp new file mode 100644 index 0000000000000..8977a3abc125d --- /dev/null +++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp @@ -0,0 +1,472 @@ +//===- SMTOps.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/APSInt.h" + +using namespace mlir; +using namespace smt; +using namespace mlir; + +//===----------------------------------------------------------------------===// +// BVConstantOp +//===----------------------------------------------------------------------===// + +LogicalResult BVConstantOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.push_back( + properties.as()->getValue().getType()); + return success(); +} + +void BVConstantOp::getAsmResultNames( + function_ref setNameFn) { + SmallVector specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << "c" << getValue().getValue() << "_bv" + << getValue().getValue().getBitWidth(); + setNameFn(getResult(), specialName.str()); +} + +OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + +//===----------------------------------------------------------------------===// +// DeclareFunOp +//===----------------------------------------------------------------------===// + +void DeclareFunOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : ""); +} + +//===----------------------------------------------------------------------===// +// SolverOp +//===----------------------------------------------------------------------===// + +LogicalResult SolverOp::verifyRegions() { + if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes()) + return emitOpError() << "types of yielded values must match return values"; + if (getBody()->getArgumentTypes() != getInputs().getTypes()) + return emitOpError() + << "block argument types must match the types of the 'inputs'"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// CheckOp +//===----------------------------------------------------------------------===// + +LogicalResult CheckOp::verifyRegions() { + if (getSatRegion().front().getTerminator()->getOperands().getTypes() != + getResultTypes()) + return emitOpError() << "types of yielded values in 'sat' region must " + "match return values"; + if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() != + getResultTypes()) + return emitOpError() << "types of yielded values in 'unknown' region must " + "match return values"; + if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() != + getResultTypes()) + return emitOpError() << "types of yielded values in 'unsat' region must " + "match return values"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// EqOp +//===----------------------------------------------------------------------===// + +static LogicalResult +parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, + OperationState &result) { + SmallVector inputs; + SMLoc loc = parser.getCurrentLocation(); + Type type; + + if (parser.parseOperandList(inputs) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(type)) + return failure(); + + result.addTypes(BoolType::get(parser.getContext())); + if (parser.resolveOperands(inputs, SmallVector(inputs.size(), type), + loc, result.operands)) + return failure(); + + return success(); +} + +ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) { + return parseSameOperandTypeVariadicToBoolOp(parser, result); +} + +void EqOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInputs(); + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : " << getInputs().front().getType(); +} + +LogicalResult EqOp::verify() { + if (getInputs().size() < 2) + return emitOpError() << "'inputs' must have at least size 2, but got " + << getInputs().size(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// DistinctOp +//===----------------------------------------------------------------------===// + +ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) { + return parseSameOperandTypeVariadicToBoolOp(parser, result); +} + +void DistinctOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInputs(); + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : " << getInputs().front().getType(); +} + +LogicalResult DistinctOp::verify() { + if (getInputs().size() < 2) + return emitOpError() << "'inputs' must have at least size 2, but got " + << getInputs().size(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ExtractOp +//===----------------------------------------------------------------------===// + +LogicalResult ExtractOp::verify() { + unsigned rangeWidth = getType().getWidth(); + unsigned inputWidth = cast(getInput().getType()).getWidth(); + if (getLowBit() + rangeWidth > inputWidth) + return emitOpError("range to be extracted is too big, expected range " + "starting at index ") + << getLowBit() << " of length " << rangeWidth + << " requires input width of at least " << (getLowBit() + rangeWidth) + << ", but the input width is only " << inputWidth; + return success(); +} + +//===----------------------------------------------------------------------===// +// ConcatOp +//===----------------------------------------------------------------------===// + +LogicalResult ConcatOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(BitVectorType::get( + context, cast(operands[0].getType()).getWidth() + + cast(operands[1].getType()).getWidth())); + return success(); +} + +//===----------------------------------------------------------------------===// +// RepeatOp +//===----------------------------------------------------------------------===// + +LogicalResult RepeatOp::verify() { + unsigned inputWidth = cast(getInput().getType()).getWidth(); + unsigned resultWidth = getType().getWidth(); + if (resultWidth % inputWidth != 0) + return emitOpError() << "result bit-vector width must be a multiple of the " + "input bit-vector width"; + + return success(); +} + +unsigned RepeatOp::getCount() { + unsigned inputWidth = cast(getInput().getType()).getWidth(); + unsigned resultWidth = getType().getWidth(); + return resultWidth / inputWidth; +} + +void RepeatOp::build(OpBuilder &builder, OperationState &state, unsigned count, + Value input) { + unsigned inputWidth = cast(input.getType()).getWidth(); + Type resultTy = BitVectorType::get(builder.getContext(), inputWidth * count); + build(builder, state, resultTy, input); +} + +ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + Type inputType; + llvm::SMLoc countLoc = parser.getCurrentLocation(); + + APInt count; + if (parser.parseInteger(count) || parser.parseKeyword("times")) + return failure(); + + if (count.isNonPositive()) + return parser.emitError(countLoc) << "integer must be positive"; + + llvm::SMLoc inputLoc = parser.getCurrentLocation(); + if (parser.parseOperand(input) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(inputType)) + return failure(); + + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + + auto bvInputTy = dyn_cast(inputType); + if (!bvInputTy) + return parser.emitError(inputLoc) << "input must have bit-vector type"; + + // Make sure no assertions can trigger and no silent overflows can happen + // Bit-width is stored as 'int64_t' parameter in 'BitVectorType' + const unsigned maxBw = 63; + if (count.getActiveBits() > maxBw) + return parser.emitError(countLoc) + << "integer must fit into " << maxBw << " bits"; + + // Store multiplication in an APInt twice the size to not have any overflow + // and check if it can be truncated to 'maxBw' bits without cutting of + // important bits. + APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw); + if (resultBw.getActiveBits() > maxBw) + return parser.emitError(countLoc) + << "result bit-width (provided integer times bit-width of the input " + "type) must fit into " + << maxBw << " bits"; + + Type resultTy = + BitVectorType::get(parser.getContext(), resultBw.getZExtValue()); + result.addTypes(resultTy); + return success(); +} + +void RepeatOp::print(OpAsmPrinter &printer) { + printer << " " << getCount() << " times " << getInput(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getInput().getType(); +} + +//===----------------------------------------------------------------------===// +// BoolConstantOp +//===----------------------------------------------------------------------===// + +void BoolConstantOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), getValue() ? "true" : "false"); +} + +OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + +//===----------------------------------------------------------------------===// +// IntConstantOp +//===----------------------------------------------------------------------===// + +void IntConstantOp::getAsmResultNames( + function_ref setNameFn) { + SmallVector specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << "c" << getValue(); + setNameFn(getResult(), specialName.str()); +} + +OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + +void IntConstantOp::print(OpAsmPrinter &p) { + p << " " << getValue(); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); +} + +ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) { + APInt value; + if (parser.parseInteger(value)) + return failure(); + + result.getOrAddProperties().setValue( + IntegerAttr::get(parser.getContext(), APSInt(value))); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + result.addTypes(smt::IntType::get(parser.getContext())); + return success(); +} + +//===----------------------------------------------------------------------===// +// ForallOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyQuantifierRegions(QuantifierOp op) { + if (op.getBoundVarNames() && + op.getBody().getNumArguments() != op.getBoundVarNames()->size()) + return op.emitOpError( + "number of bound variable names must match number of block arguments"); + if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType)) + return op.emitOpError() + << "bound variables must by any non-function SMT value"; + + if (op.getBody().front().getTerminator()->getNumOperands() != 1) + return op.emitOpError("must have exactly one yielded value"); + if (!isa( + op.getBody().front().getTerminator()->getOperand(0).getType())) + return op.emitOpError("yielded value must be of '!smt.bool' type"); + + for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) { + unsigned i = regionWithIndex.index(); + Region ®ion = regionWithIndex.value(); + + if (op.getBody().getArgumentTypes() != region.getArgumentTypes()) + return op.emitOpError() + << "block argument number and types of the 'body' " + "and 'patterns' region #" + << i << " must match"; + if (region.front().getTerminator()->getNumOperands() < 1) + return op.emitOpError() << "'patterns' region #" << i + << " must have at least one yielded value"; + + // All operations in the 'patterns' region must be SMT operations. + auto result = region.walk([&](Operation *childOp) { + if (!isa(childOp->getDialect())) { + auto diag = op.emitOpError() + << "the 'patterns' region #" << i + << " may only contain SMT dialect operations"; + diag.attachNote(childOp->getLoc()) << "first non-SMT operation here"; + return WalkResult::interrupt(); + } + + // There may be no quantifier (or other variable binding) operations in + // the 'patterns' region. + if (isa(childOp)) { + auto diag = op.emitOpError() << "the 'patterns' region #" << i + << " must not contain " + "any variable binding operations"; + diag.attachNote(childOp->getLoc()) << "first violating operation here"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return failure(); + } + + return success(); +} + +template +static void buildQuantifier( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + odsState.addTypes(BoolType::get(odsBuilder.getContext())); + if (weight != 0) + odsState.getOrAddProperties().weight = + odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight); + if (noPattern) + odsState.getOrAddProperties().noPattern = + odsBuilder.getUnitAttr(); + if (boundVarNames.has_value()) { + SmallVector boundVarNamesList; + for (StringRef str : *boundVarNames) + boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str)); + odsState.getOrAddProperties().boundVarNames = + odsBuilder.getArrayAttr(boundVarNamesList); + } + { + OpBuilder::InsertionGuard guard(odsBuilder); + Region *region = odsState.addRegion(); + Block *block = odsBuilder.createBlock(region); + block->addArguments( + boundVarTypes, + SmallVector(boundVarTypes.size(), odsState.location)); + Value returnVal = + bodyBuilder(odsBuilder, odsState.location, block->getArguments()); + odsBuilder.create(odsState.location, returnVal); + } + if (patternBuilder) { + Region *region = odsState.addRegion(); + OpBuilder::InsertionGuard guard(odsBuilder); + Block *block = odsBuilder.createBlock(region); + block->addArguments( + boundVarTypes, + SmallVector(boundVarTypes.size(), odsState.location)); + ValueRange returnVals = + patternBuilder(odsBuilder, odsState.location, block->getArguments()); + odsBuilder.create(odsState.location, returnVals); + } +} + +LogicalResult ForallOp::verify() { + if (!getPatterns().empty() && getNoPattern()) + return emitOpError() << "patterns and the no_pattern attribute must not be " + "specified at the same time"; + + return success(); +} + +LogicalResult ForallOp::verifyRegions() { + return verifyQuantifierRegions(*this); +} + +void ForallOp::build( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + buildQuantifier(odsBuilder, odsState, boundVarTypes, bodyBuilder, + boundVarNames, patternBuilder, weight, noPattern); +} + +//===----------------------------------------------------------------------===// +// ExistsOp +//===----------------------------------------------------------------------===// + +LogicalResult ExistsOp::verify() { + if (!getPatterns().empty() && getNoPattern()) + return emitOpError() << "patterns and the no_pattern attribute must not be " + "specified at the same time"; + + return success(); +} + +LogicalResult ExistsOp::verifyRegions() { + return verifyQuantifierRegions(*this); +} + +void ExistsOp::build( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + buildQuantifier(odsBuilder, odsState, boundVarTypes, bodyBuilder, + boundVarNames, patternBuilder, weight, noPattern); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/SMT/IR/SMT.cpp.inc" diff --git a/mlir/lib/Dialect/SMT/IR/SMTTypes.cpp b/mlir/lib/Dialect/SMT/IR/SMTTypes.cpp new file mode 100644 index 0000000000000..6188719bb1ab5 --- /dev/null +++ b/mlir/lib/Dialect/SMT/IR/SMTTypes.cpp @@ -0,0 +1,92 @@ +//===- SMTTypes.cpp -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SMT/IR/SMTTypes.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace smt; +using namespace mlir; + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/SMT/IR/SMTTypes.cpp.inc" + +void SMTDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/SMT/IR/SMTTypes.cpp.inc" + >(); +} + +bool smt::isAnyNonFuncSMTValueType(Type type) { + return isAnySMTValueType(type) && !isa(type); +} + +bool smt::isAnySMTValueType(Type type) { + return isa(type); +} + +//===----------------------------------------------------------------------===// +// BitVectorType +//===----------------------------------------------------------------------===// + +LogicalResult +BitVectorType::verify(function_ref emitError, + int64_t width) { + if (width <= 0U) + return emitError() << "bit-vector must have at least a width of one"; + return success(); +} + +//===----------------------------------------------------------------------===// +// ArrayType +//===----------------------------------------------------------------------===// + +LogicalResult ArrayType::verify(function_ref emitError, + Type domainType, Type rangeType) { + if (!isAnySMTValueType(domainType)) + return emitError() << "domain must be any SMT value type"; + if (!isAnySMTValueType(rangeType)) + return emitError() << "range must be any SMT value type"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SMTFuncType +//===----------------------------------------------------------------------===// + +LogicalResult SMTFuncType::verify(function_ref emitError, + ArrayRef domainTypes, Type rangeType) { + if (domainTypes.empty()) + return emitError() << "domain must not be empty"; + if (!llvm::all_of(domainTypes, isAnyNonFuncSMTValueType)) + return emitError() << "domain types must be any non-function SMT type"; + if (!isAnyNonFuncSMTValueType(rangeType)) + return emitError() << "range type must be any non-function SMT type"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SortType +//===----------------------------------------------------------------------===// + +LogicalResult SortType::verify(function_ref emitError, + StringAttr identifier, + ArrayRef sortParams) { + if (!llvm::all_of(sortParams, isAnyNonFuncSMTValueType)) + return emitError() + << "sort parameter types must be any non-function SMT type"; + + return success(); +} diff --git a/mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir b/mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir new file mode 100644 index 0000000000000..a1cf033a461c5 --- /dev/null +++ b/mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-opt %s --convert-arith-to-smt | FileCheck %s + +// CHECK-LABEL: func @test +// CHECK-SAME: ([[A0:%.+]]: !smt.bv<32>, [[A1:%.+]]: !smt.bv<32>, [[A2:%.+]]: !smt.bv<32>, [[A3:%.+]]: !smt.bv<32>, [[A4:%.+]]: !smt.bv<1>, [[ARG5:%.+]]: !smt.bv<4>) +func.func @test(%a0: !smt.bv<32>, %a1: !smt.bv<32>, %a2: !smt.bv<32>, %a3: !smt.bv<32>, %a4: !smt.bv<1>, %a5: !smt.bv<4>) { + %arg0 = builtin.unrealized_conversion_cast %a0 : !smt.bv<32> to i32 + %arg1 = builtin.unrealized_conversion_cast %a1 : !smt.bv<32> to i32 + %arg2 = builtin.unrealized_conversion_cast %a2 : !smt.bv<32> to i32 + %arg3 = builtin.unrealized_conversion_cast %a3 : !smt.bv<32> to i32 + %arg4 = builtin.unrealized_conversion_cast %a4 : !smt.bv<1> to i1 + %arg5 = builtin.unrealized_conversion_cast %a5 : !smt.bv<4> to i4 + + // CHECK: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> + %0 = arith.divsi %arg0, %arg1 : i32 + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> + %1 = arith.divui %arg0, %arg1 : i32 + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.srem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> + %2 = arith.remsi %arg0, %arg1 : i32 + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.urem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> + %3 = arith.remui %arg0, %arg1 : i32 + + // CHECK-NEXT: [[NEG:%.+]] = smt.bv.neg [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.bv.add [[A0]], [[NEG]] : !smt.bv<32> + %7 = arith.subi %arg0, %arg1 : i32 + + // CHECK-NEXT: [[A5:%.+]] = smt.bv.add [[A0]], [[A1]] : !smt.bv<32> + %8 = arith.addi %arg0, %arg1 : i32 + // CHECK-NEXT: [[B1:%.+]] = smt.bv.mul [[A0]], [[A1]] : !smt.bv<32> + %9 = arith.muli %arg0, %arg1 : i32 + // CHECK-NEXT: [[C1:%.+]] = smt.bv.and [[A0]], [[A1]] : !smt.bv<32> + %10 = arith.andi %arg0, %arg1 : i32 + // CHECK-NEXT: [[D1:%.+]] = smt.bv.or [[A0]], [[A1]] : !smt.bv<32> + %11 = arith.ori %arg0, %arg1 : i32 + // CHECK-NEXT: [[E1:%.+]] = smt.bv.xor [[A0]], [[A1]] : !smt.bv<32> + %12 = arith.xori %arg0, %arg1 : i32 + + // CHECK-NEXT: smt.eq [[A0]], [[A1]] : !smt.bv<32> + %14 = arith.cmpi eq, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.distinct [[A0]], [[A1]] : !smt.bv<32> + %15 = arith.cmpi ne, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp sle [[A0]], [[A1]] : !smt.bv<32> + %20 = arith.cmpi sle, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp slt [[A0]], [[A1]] : !smt.bv<32> + %21 = arith.cmpi slt, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp ule [[A0]], [[A1]] : !smt.bv<32> + %22 = arith.cmpi ule, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp ult [[A0]], [[A1]] : !smt.bv<32> + %23 = arith.cmpi ult, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp sge [[A0]], [[A1]] : !smt.bv<32> + %24 = arith.cmpi sge, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp sgt [[A0]], [[A1]] : !smt.bv<32> + %25 = arith.cmpi sgt, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp uge [[A0]], [[A1]] : !smt.bv<32> + %26 = arith.cmpi uge, %arg0, %arg1 : i32 + // CHECK-NEXT: smt.bv.cmp ugt [[A0]], [[A1]] : !smt.bv<32> + %27 = arith.cmpi ugt, %arg0, %arg1 : i32 + + // CHECK-NEXT: %{{.*}} = smt.bv.shl [[A0]], [[A1]] : !smt.bv<32> + %32 = arith.shli %arg0, %arg1 : i32 + // CHECK-NEXT: %{{.*}} = smt.bv.ashr [[A0]], [[A1]] : !smt.bv<32> + %33 = arith.shrsi %arg0, %arg1 : i32 + // CHECK-NEXT: %{{.*}} = smt.bv.lshr [[A0]], [[A1]] : !smt.bv<32> + %34 = arith.shrui %arg0, %arg1 : i32 + + // The arith.cmpi folder is called before the conversion patterns and produces + // a `arith.constant` operation. + // CHECK-NEXT: smt.bv.constant #smt.bv<-1> : !smt.bv<1> + %35 = arith.cmpi eq, %arg0, %arg0 : i32 + + return +} diff --git a/mlir/test/Dialect/SMT/array-errors.mlir b/mlir/test/Dialect/SMT/array-errors.mlir new file mode 100644 index 0000000000000..4e90948eed848 --- /dev/null +++ b/mlir/test/Dialect/SMT/array-errors.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s --split-input-file --verify-diagnostics + +// expected-error @below {{domain must be any SMT value type}} +func.func @array_domain_no_smt_type(%arg0: !smt.array<[i32 -> !smt.bool]>) { + return +} + +// ----- + +// expected-error @below {{range must be any SMT value type}} +func.func @array_range_no_smt_type(%arg0: !smt.array<[!smt.bool -> i32]>) { + return +} diff --git a/mlir/test/Dialect/SMT/array.mlir b/mlir/test/Dialect/SMT/array.mlir new file mode 100644 index 0000000000000..89cb45c5e878a --- /dev/null +++ b/mlir/test/Dialect/SMT/array.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @arrayOperations +// CHECK-SAME: ([[A0:%.+]]: !smt.bool) +func.func @arrayOperations(%arg0: !smt.bool) { + // CHECK-NEXT: [[V0:%.+]] = smt.array.broadcast [[A0]] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]> + %0 = smt.array.broadcast %arg0 {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]> + // CHECK-NEXT: [[V1:%.+]] = smt.array.select [[V0]][[[A0]]] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]> + %1 = smt.array.select %0[%arg0] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]> + // CHECK-NEXT: [[V2:%.+]] = smt.array.store [[V0]][[[A0]]], [[A0]] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]> + %2 = smt.array.store %0[%arg0], %arg0 {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]> + + return +} diff --git a/mlir/test/Dialect/SMT/basic.mlir b/mlir/test/Dialect/SMT/basic.mlir new file mode 100644 index 0000000000000..a4975d66e9769 --- /dev/null +++ b/mlir/test/Dialect/SMT/basic.mlir @@ -0,0 +1,200 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @types +// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>, %{{.*}}: !smt.int, %{{.*}}: !smt.sort<"uninterpreted_sort">, %{{.*}}: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %{{.*}}: !smt.func<(!smt.bool, !smt.bool) !smt.bool>) +func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>, %arg2: !smt.int, %arg3: !smt.sort<"uninterpreted_sort">, %arg4: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %arg5: !smt.func<(!smt.bool, !smt.bool) !smt.bool>) { + return +} + +func.func @core(%in: i8) { + // CHECK: %a = smt.declare_fun "a" {smt.some_attr} : !smt.bool + %a = smt.declare_fun "a" {smt.some_attr} : !smt.bool + // CHECK: smt.declare_fun {smt.some_attr} : !smt.bv<32> + %b = smt.declare_fun {smt.some_attr} : !smt.bv<32> + // CHECK: smt.declare_fun {smt.some_attr} : !smt.int + %c = smt.declare_fun {smt.some_attr} : !smt.int + // CHECK: smt.declare_fun {smt.some_attr} : !smt.sort<"uninterpreted_sort"> + %d = smt.declare_fun {smt.some_attr} : !smt.sort<"uninterpreted_sort"> + // CHECK: smt.declare_fun {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + %e = smt.declare_fun {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + + // CHECK: smt.constant true {smt.some_attr} + %true = smt.constant true {smt.some_attr} + // CHECK: smt.constant false {smt.some_attr} + %false = smt.constant false {smt.some_attr} + + // CHECK: smt.assert %a {smt.some_attr} + smt.assert %a {smt.some_attr} + + // CHECK: smt.reset {smt.some_attr} + smt.reset {smt.some_attr} + + // CHECK: smt.push 1 {smt.some_attr} + smt.push 1 {smt.some_attr} + + // CHECK: smt.pop 1 {smt.some_attr} + smt.pop 1 {smt.some_attr} + + // CHECK: %{{.*}} = smt.solver(%{{.*}}) {smt.some_attr} : (i8) -> (i8, i32) { + // CHECK: ^bb0(%{{.*}}: i8) + // CHECK: %{{.*}} = smt.check {smt.some_attr} sat { + // CHECK: smt.yield %{{.*}} : i32 + // CHECK: } unknown { + // CHECK: smt.yield %{{.*}} : i32 + // CHECK: } unsat { + // CHECK: smt.yield %{{.*}} : i32 + // CHECK: } -> i32 + // CHECK: smt.yield %{{.*}}, %{{.*}} : i8, i32 + // CHECK: } + %0:2 = smt.solver(%in) {smt.some_attr} : (i8) -> (i8, i32) { + ^bb0(%arg0: i8): + %1 = smt.check {smt.some_attr} sat { + %c1_i32 = arith.constant 1 : i32 + smt.yield %c1_i32 : i32 + } unknown { + %c0_i32 = arith.constant 0 : i32 + smt.yield %c0_i32 : i32 + } unsat { + %c-1_i32 = arith.constant -1 : i32 + smt.yield %c-1_i32 : i32 + } -> i32 + smt.yield %arg0, %1 : i8, i32 + } + + // CHECK: smt.solver() : () -> () { + // CHECK-NEXT: } + smt.solver() : () -> () { } + + // CHECK: smt.solver() : () -> () { + // CHECK-NEXT: smt.set_logic "AUFLIA" + // CHECK-NEXT: } + smt.solver() : () -> () { + smt.set_logic "AUFLIA" + } + + // CHECK: smt.check sat { + // CHECK-NEXT: } unknown { + // CHECK-NEXT: } unsat { + // CHECK-NEXT: } + smt.check sat { } unknown { } unsat { } + + // CHECK: %{{.*}} = smt.eq %{{.*}}, %{{.*}} {smt.some_attr} : !smt.bv<32> + %1 = smt.eq %b, %b {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.distinct %{{.*}}, %{{.*}} {smt.some_attr} : !smt.bv<32> + %2 = smt.distinct %b, %b {smt.some_attr} : !smt.bv<32> + + // CHECK: %{{.*}} = smt.eq %{{.*}}, %{{.*}}, %{{.*}} : !smt.bool + %3 = smt.eq %a, %a, %a : !smt.bool + // CHECK: %{{.*}} = smt.distinct %{{.*}}, %{{.*}}, %{{.*}} : !smt.bool + %4 = smt.distinct %a, %a, %a : !smt.bool + + // CHECK: %{{.*}} = smt.ite %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr} : !smt.bv<32> + %5 = smt.ite %a, %b, %b {smt.some_attr} : !smt.bv<32> + + // CHECK: %{{.*}} = smt.not %{{.*}} {smt.some_attr} + %6 = smt.not %a {smt.some_attr} + // CHECK: %{{.*}} = smt.and %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr} + %7 = smt.and %a, %a, %a {smt.some_attr} + // CHECK: %{{.*}} = smt.or %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr} + %8 = smt.or %a, %a, %a {smt.some_attr} + // CHECK: %{{.*}} = smt.xor %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr} + %9 = smt.xor %a, %a, %a {smt.some_attr} + // CHECK: %{{.*}} = smt.implies %{{.*}}, %{{.*}} {smt.some_attr} + %10 = smt.implies %a, %a {smt.some_attr} + + // CHECK: smt.apply_func %{{.*}}(%{{.*}}, %{{.*}}) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + %11 = smt.apply_func %e(%c, %a) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + + return +} + +// CHECK-LABEL: func @quantifiers +func.func @quantifiers() { + // CHECK-NEXT: smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } patterns { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: }, { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %0 = smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %1 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %1 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + }, { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + + // CHECK-NEXT: smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %1 = smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %2 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %2 : !smt.bool + } + + // CHECK-NEXT: smt.forall { + // CHECK-NEXT: smt.constant + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %2 = smt.forall { + %3 = smt.constant true + smt.yield %3 : !smt.bool + } + + // CHECK-NEXT: smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } patterns { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: }, { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %3 = smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %4 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %4 : !smt.bool {smt.some_attr} + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + }, { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + + // CHECK-NEXT: smt.exists no_pattern attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %4 = smt.exists no_pattern attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %5 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %5 : !smt.bool {smt.some_attr} + } + + // CHECK-NEXT: smt.exists [] { + // CHECK-NEXT: smt.constant + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %5 = smt.exists [] { + %6 = smt.constant true + smt.yield %6 : !smt.bool + } + + return +} diff --git a/mlir/test/Dialect/SMT/bitvector-errors.mlir b/mlir/test/Dialect/SMT/bitvector-errors.mlir new file mode 100644 index 0000000000000..58226f4d55f62 --- /dev/null +++ b/mlir/test/Dialect/SMT/bitvector-errors.mlir @@ -0,0 +1,112 @@ +// RUN: mlir-opt %s --split-input-file --verify-diagnostics + +// expected-error @below {{bit-vector must have at least a width of one}} +func.func @at_least_size_one(%arg0: !smt.bv<0>) { + return +} + +// ----- + +// expected-error @below {{bit-vector must have at least a width of one}} +func.func @positive_width(%arg0: !smt.bv<-1>) { + return +} + +// ----- + +func.func @attr_type_and_return_type_match() { + // expected-error @below {{inferred type(s) '!smt.bv<1>' are incompatible with return type(s) of operation '!smt.bv<32>'}} + // expected-error @below {{failed to infer returned types}} + %c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<0> : !smt.bv<1>}> : () -> !smt.bv<32> + return +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{explicit bit-vector type required}} + smt.bv.constant #smt.bv<5> +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{integer value out of range for given bit-vector type}} + smt.bv.constant #smt.bv<32> : !smt.bv<2> +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{integer value out of range for given bit-vector type}} + smt.bv.constant #smt.bv<-4> : !smt.bv<2> +} + +// ----- + +func.func @extraction(%arg0: !smt.bv<32>) { + // expected-error @below {{range to be extracted is too big, expected range starting at index 20 of length 16 requires input width of at least 36, but the input width is only 32}} + smt.bv.extract %arg0 from 20 : (!smt.bv<32>) -> !smt.bv<16> + return +} + +// ----- + +func.func @concat(%arg0: !smt.bv<32>) { + // expected-error @below {{inferred type(s) '!smt.bv<64>' are incompatible with return type(s) of operation '!smt.bv<33>'}} + // expected-error @below {{failed to infer returned types}} + "smt.bv.concat"(%arg0, %arg0) {} : (!smt.bv<32>, !smt.bv<32>) -> !smt.bv<33> + return +} + +// ----- + +func.func @repeat_result_type_no_multiple_of_input_type(%arg0: !smt.bv<32>) { + // expected-error @below {{result bit-vector width must be a multiple of the input bit-vector width}} + "smt.bv.repeat"(%arg0) : (!smt.bv<32>) -> !smt.bv<65> + return +} + +// ----- + +func.func @repeat_negative_count(%arg0: !smt.bv<32>) { + // expected-error @below {{integer must be positive}} + smt.bv.repeat -2 times %arg0 : !smt.bv<32> + return +} + +// ----- + +// The parser has to extract the bit-width of the input and thus we need to +// test that this is handled correctly in the parser, we cannot just rely on the +// verifier. +func.func @repeat_wrong_input_type(%arg0: !smt.bool) { + // expected-error @below {{input must have bit-vector type}} + smt.bv.repeat 2 times %arg0 : !smt.bool + return +} + +// ----- + +func.func @repeat_count_too_large(%arg0: !smt.bv<32>) { + // expected-error @below {{integer must fit into 63 bits}} + smt.bv.repeat 18446744073709551617 times %arg0 : !smt.bv<32> + return +} + +// ----- + +func.func @repeat_result_type_bitwidth_too_large(%arg0: !smt.bv<9223372036854775807>) { + // expected-error @below {{result bit-width (provided integer times bit-width of the input type) must fit into 63 bits}} + smt.bv.repeat 2 times %arg0 : !smt.bv<9223372036854775807> + return +} + +// ----- + +func.func @invalid_bv2int_signedness() { + %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> + // expected-error @below {{expected ':'}} + %bv2int = smt.bv2int %c5_bv32 unsigned : !smt.bv<32> + return +} diff --git a/mlir/test/Dialect/SMT/bitvectors.mlir b/mlir/test/Dialect/SMT/bitvectors.mlir new file mode 100644 index 0000000000000..2482f55b5ed31 --- /dev/null +++ b/mlir/test/Dialect/SMT/bitvectors.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @bitvectors +func.func @bitvectors() { + // CHECK: %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr} + %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr} + // CHECK: %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> {smt.some_attr} + %c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr} + // CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> + %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> + // CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1> + %c-1_bv1_neg = smt.bv.constant #smt.bv<-1> : !smt.bv<1> + // CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1> + %c-1_bv1_pos = smt.bv.constant #smt.bv<1> : !smt.bv<1> + + // CHECK: [[C0:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + %c = smt.bv.constant #smt.bv<0> : !smt.bv<32> + + // CHECK: %{{.*}} = smt.bv.neg [[C0]] {smt.some_attr} : !smt.bv<32> + %0 = smt.bv.neg %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.add [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %1 = smt.bv.add %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.mul [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %3 = smt.bv.mul %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.urem [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %4 = smt.bv.urem %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.srem [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %5 = smt.bv.srem %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.smod [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %7 = smt.bv.smod %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.shl [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %8 = smt.bv.shl %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.lshr [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %9 = smt.bv.lshr %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.ashr [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %10 = smt.bv.ashr %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.udiv [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %11 = smt.bv.udiv %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.sdiv [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %12 = smt.bv.sdiv %c, %c {smt.some_attr} : !smt.bv<32> + + // CHECK: %{{.*}} = smt.bv.not [[C0]] {smt.some_attr} : !smt.bv<32> + %13 = smt.bv.not %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.and [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %14 = smt.bv.and %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.or [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %15 = smt.bv.or %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.xor [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %16 = smt.bv.xor %c, %c {smt.some_attr} : !smt.bv<32> + + // CHECK: %{{.*}} = smt.bv.cmp slt [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %17 = smt.bv.cmp slt %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp sle [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %18 = smt.bv.cmp sle %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp sgt [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %19 = smt.bv.cmp sgt %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp sge [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %20 = smt.bv.cmp sge %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp ult [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %21 = smt.bv.cmp ult %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp ule [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %22 = smt.bv.cmp ule %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp ugt [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %23 = smt.bv.cmp ugt %c, %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.cmp uge [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32> + %24 = smt.bv.cmp uge %c, %c {smt.some_attr} : !smt.bv<32> + + // CHECK: %{{.*}} = smt.bv.concat [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>, !smt.bv<32> + %25 = smt.bv.concat %c, %c {smt.some_attr} : !smt.bv<32>, !smt.bv<32> + // CHECK: %{{.*}} = smt.bv.extract [[C0]] from 8 {smt.some_attr} : (!smt.bv<32>) -> !smt.bv<16> + %26 = smt.bv.extract %c from 8 {smt.some_attr} : (!smt.bv<32>) -> !smt.bv<16> + // CHECK: %{{.*}} = smt.bv.repeat 2 times [[C0]] {smt.some_attr} : !smt.bv<32> + %27 = smt.bv.repeat 2 times %c {smt.some_attr} : !smt.bv<32> + + // CHECK: %{{.*}} = smt.bv2int [[C0]] {smt.some_attr} : !smt.bv<32> + %29 = smt.bv2int %c {smt.some_attr} : !smt.bv<32> + // CHECK: %{{.*}} = smt.bv2int [[C0]] signed {smt.some_attr} : !smt.bv<32> + %28 = smt.bv2int %c signed {smt.some_attr} : !smt.bv<32> + + return +} diff --git a/mlir/test/Dialect/SMT/core-errors.mlir b/mlir/test/Dialect/SMT/core-errors.mlir new file mode 100644 index 0000000000000..67bebda56b68e --- /dev/null +++ b/mlir/test/Dialect/SMT/core-errors.mlir @@ -0,0 +1,497 @@ +// RUN: mlir-opt %s --split-input-file --verify-diagnostics + +func.func @solver_isolated_from_above(%arg0: !smt.bool) { + // expected-note @below {{required by region isolation constraints}} + smt.solver() : () -> () { + // expected-error @below {{using value defined outside the region}} + smt.assert %arg0 + } + return +} + +// ----- + +func.func @no_smt_value_enters_solver(%arg0: !smt.bool) { + // expected-error @below {{operand #0 must be variadic of any non-smt type, but got '!smt.bool'}} + smt.solver(%arg0) : (!smt.bool) -> () { + ^bb0(%arg1: !smt.bool): + smt.assert %arg1 + } + return +} + +// ----- + +func.func @no_smt_value_exits_solver() { + // expected-error @below {{result #0 must be variadic of any non-smt type, but got '!smt.bool'}} + %0 = smt.solver() : () -> !smt.bool { + %a = smt.declare_fun "a" : !smt.bool + smt.yield %a : !smt.bool + } + return +} + +// ----- + +func.func @block_args_and_inputs_match() { + // expected-error @below {{block argument types must match the types of the 'inputs'}} + smt.solver() : () -> () { + ^bb0(%arg0: i32): + } + return +} + +// ----- + +func.func @solver_yield_operands_and_results_match() { + // expected-error @below {{types of yielded values must match return values}} + smt.solver() : () -> () { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } + return +} + +// ----- + +func.func @check_yield_operands_and_results_match() { + // expected-error @below {{types of yielded values in 'unsat' region must match return values}} + %0 = smt.check sat { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } unknown { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } unsat { } -> i32 + return +} + +// ----- + +func.func @check_yield_operands_and_results_match() { + // expected-error @below {{types of yielded values in 'unknown' region must match return values}} + %0 = smt.check sat { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } unknown { + } unsat { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } -> i32 + return +} + +// ----- + +func.func @check_yield_operands_and_results_match() { + // expected-error @below {{types of yielded values in 'sat' region must match return values}} + %0 = smt.check sat { + } unknown { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } unsat { + %1 = arith.constant 0 : i32 + smt.yield %1 : i32 + } -> i32 + return +} + +// ----- + +func.func @check_no_block_arguments() { + // expected-error @below {{region #0 should have no arguments}} + smt.check sat { + ^bb0(%arg0: i32): + } unknown { + } unsat { + } + return +} + +// ----- + +func.func @check_no_block_arguments() { + // expected-error @below {{region #1 should have no arguments}} + smt.check sat { + } unknown { + ^bb0(%arg0: i32): + } unsat { + } + return +} + +// ----- + +func.func @check_no_block_arguments() { + // expected-error @below {{region #2 should have no arguments}} + smt.check sat { + } unknown { + } unsat { + ^bb0(%arg0: i32): + } + return +} + +// ----- + +func.func @too_few_operands() { + // expected-error @below {{'inputs' must have at least size 2, but got 0}} + smt.eq : !smt.bool + return +} + +// ----- + +func.func @too_few_operands(%a: !smt.bool) { + // expected-error @below {{'inputs' must have at least size 2, but got 1}} + smt.distinct %a : !smt.bool + return +} + +// ----- + +func.func @ite_type_mismatch(%a: !smt.bool, %b: !smt.bv<32>) { + // expected-error @below {{failed to verify that all of {thenValue, elseValue, result} have same type}} + "smt.ite"(%a, %a, %b) {} : (!smt.bool, !smt.bool, !smt.bv<32>) -> !smt.bool + return +} + +// ----- + +func.func @forall_number_of_decl_names_must_match_num_args() { + // expected-error @below {{number of bound variable names must match number of block arguments}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.eq %arg2, %arg3 : !smt.int + smt.yield %2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_number_of_decl_names_must_match_num_args() { + // expected-error @below {{number of bound variable names must match number of block arguments}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.eq %arg2, %arg3 : !smt.int + smt.yield %2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{yielded value must be of '!smt.bool' type}} + %1 = smt.forall ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.int.add %arg2, %arg3 + smt.yield %2 : !smt.int + } + return +} + +// ----- + +func.func @forall_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{must have exactly one yielded value}} + %1 = smt.forall ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + smt.yield + } + return +} + +// ----- + +func.func @exists_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{yielded value must be of '!smt.bool' type}} + %1 = smt.exists ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.int.add %arg2, %arg3 + smt.yield %2 : !smt.int + } + return +} + +// ----- + +func.func @exists_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{must have exactly one yielded value}} + %1 = smt.exists ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + smt.yield + } + return +} + +// ----- + +func.func @exists_patterns_region_and_no_patterns_attr_are_mutually_exclusive() { + // expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}} + %1 = smt.exists ["a"] no_pattern { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_and_no_patterns_attr_are_mutually_exclusive() { + // expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}} + %1 = smt.forall ["a"] no_pattern { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_patterns_region_num_args() { + // expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_num_args() { + // expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + return +} + +// ----- + +func.func @exists_patterns_region_at_least_one_yielded_value() { + // expected-error @below {{'patterns' region #0 must have at least one yielded value}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @forall_patterns_region_at_least_one_yielded_value() { + // expected-error @below {{'patterns' region #0 must have at least one yielded value}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @exists_all_pattern_regions_tested() { + // expected-error @below {{'patterns' region #1 must have at least one yielded value}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + }, { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @forall_all_pattern_regions_tested() { + // expected-error @below {{'patterns' region #1 must have at least one yielded value}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + }, { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @exists_patterns_region_no_non_smt_operations() { + // expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first non-SMT operation here}} + arith.constant 0 : i32 + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_no_non_smt_operations() { + // expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first non-SMT operation here}} + arith.constant 0 : i32 + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_patterns_region_no_var_binding_operations() { + // expected-error @below {{'patterns' region #0 must not contain any variable binding operations}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first violating operation here}} + smt.exists ["b"] { + ^bb0(%arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_no_var_binding_operations() { + // expected-error @below {{'patterns' region #0 must not contain any variable binding operations}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first violating operation here}} + smt.forall ["b"] { + ^bb0(%arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_bound_variable_type_invalid() { + // expected-error @below {{bound variables must by any non-function SMT value}} + %1 = smt.exists ["a", "b"] { + ^bb0(%arg2: !smt.func<(!smt.int) !smt.int>, %arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + return +} + +// ----- + +func.func @forall_bound_variable_type_invalid() { + // expected-error @below {{bound variables must by any non-function SMT value}} + %1 = smt.forall ["a", "b"] { + ^bb0(%arg2: !smt.func<(!smt.int) !smt.int>, %arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + return +} + +// ----- + +// expected-error @below {{domain types must be any non-function SMT type}} +func.func @func_domain_no_smt_type(%arg0: !smt.func<(i32) !smt.bool>) { + return +} + +// ----- + +// expected-error @below {{range type must be any non-function SMT type}} +func.func @func_range_no_smt_type(%arg0: !smt.func<(!smt.bool) i32>) { + return +} + +// ----- + +// expected-error @below {{range type must be any non-function SMT type}} +func.func @func_range_no_smt_type(%arg0: !smt.func<(!smt.bool) !smt.func<(!smt.bool) !smt.bool>>) { + return +} + +// ----- + +func.func @func_range_no_smt_type(%arg0: !smt.func<(!smt.bool) !smt.bool>) { + // expected-error @below {{got 0 operands and 1 types}} + smt.apply_func %arg0() : !smt.func<(!smt.bool) !smt.bool> + return +} + +// ----- + +// expected-error @below {{sort parameter types must be any non-function SMT type}} +func.func @sort_type_no_smt_type(%arg0: !smt.sort<"sortname"[i32]>) { + return +} + +// ----- + +func.func @negative_push() { + // expected-error @below {{smt.push' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} + smt.push -1 + return +} + +// ----- + +func.func @negative_pop() { + // expected-error @below {{smt.pop' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} + smt.pop -1 + return +} + +// ----- + +func.func @set_logic_outside_solver() { + // expected-error @below {{'smt.set_logic' op expects parent op 'smt.solver'}} + smt.set_logic "AUFLIA" + return +} diff --git a/mlir/test/Dialect/SMT/cse-test.mlir b/mlir/test/Dialect/SMT/cse-test.mlir new file mode 100644 index 0000000000000..ff254857f3b33 --- /dev/null +++ b/mlir/test/Dialect/SMT/cse-test.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s --cse | FileCheck %s + +func.func @declare_const_cse(%in: i8) -> (!smt.bool, !smt.bool){ + // CHECK: smt.declare_fun "a" : !smt.bool + %a = smt.declare_fun "a" : !smt.bool + // CHECK-NEXT: smt.declare_fun "a" : !smt.bool + %b = smt.declare_fun "a" : !smt.bool + // CHECK-NEXT: return + %c = smt.declare_fun "a" : !smt.bool + + return %a, %b : !smt.bool, !smt.bool +} diff --git a/mlir/test/Dialect/SMT/integers.mlir b/mlir/test/Dialect/SMT/integers.mlir new file mode 100644 index 0000000000000..f5133c8c72b5d --- /dev/null +++ b/mlir/test/Dialect/SMT/integers.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @integer_operations +func.func @integer_operations() { + // CHECK-NEXT: [[V0:%.+]] = smt.int.constant -123 {smt.some_attr} + %0 = smt.int.constant -123 {smt.some_attr} + // CHECK-NEXT: %c184467440737095516152 = smt.int.constant 184467440737095516152 {smt.some_attr} + %1 = smt.int.constant 184467440737095516152 {smt.some_attr} + + + // CHECK-NEXT: smt.int.add [[V0]], [[V0]], [[V0]] {smt.some_attr} + %2 = smt.int.add %0, %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.mul [[V0]], [[V0]], [[V0]] {smt.some_attr} + %3 = smt.int.mul %0, %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.sub [[V0]], [[V0]] {smt.some_attr} + %4 = smt.int.sub %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.div [[V0]], [[V0]] {smt.some_attr} + %5 = smt.int.div %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.mod [[V0]], [[V0]] {smt.some_attr} + %6 = smt.int.mod %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.abs [[V0]] {smt.some_attr} + %7 = smt.int.abs %0 {smt.some_attr} + + // CHECK-NEXT: smt.int.cmp le [[V0]], [[V0]] {smt.some_attr} + %9 = smt.int.cmp le %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.cmp lt [[V0]], [[V0]] {smt.some_attr} + %10 = smt.int.cmp lt %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.cmp ge [[V0]], [[V0]] {smt.some_attr} + %11 = smt.int.cmp ge %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int.cmp gt [[V0]], [[V0]] {smt.some_attr} + %12 = smt.int.cmp gt %0, %0 {smt.some_attr} + // CHECK-NEXT: smt.int2bv [[V0]] {smt.some_attr} : !smt.bv<4> + %13 = smt.int2bv %0 {smt.some_attr} : !smt.bv<4> + + return +}