diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index 2d3ed60a35fd9..7d59add3d37c2 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -161,6 +161,17 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { The coefficient and polynomial modulus parameters are optional, and the coefficient modulus is only allowed if the coefficient type is integral. + + The coefficient modulus, if specified, should be positive and not larger + than `2 ** width(coefficientType)`. + + If the coefficient modulus is not specified, the handling of coefficients + overflows is determined by subsequent lowering passes, which may choose to + wrap around or widen the overflow at their discretion. + + Note that coefficient modulus is contained in `i64` by default, which is signed. + To specify a 64 bit number without intepreting it as a negative number, its container + type should be manually specified like `coefficientModulus=18446744073709551615:i128`. }]; let parameters = (ins @@ -168,6 +179,7 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus, OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus ); + let genVerifyDecl = 1; let assemblyFormat = "`<` struct(params) `>`"; let builders = [ AttrBuilderWithInferredContext< diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index 93c7f9e687fc7..cd7789a2e9531 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -203,5 +203,34 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { return FloatPolynomialAttr::get(parser.getContext(), result.value()); } +LogicalResult +RingAttr::verify(function_ref emitError, + Type coefficientType, IntegerAttr coefficientModulus, + IntPolynomialAttr polynomialModulus) { + if (coefficientModulus) { + auto coeffIntType = llvm::dyn_cast(coefficientType); + if (!coeffIntType) { + return emitError() << "coefficientModulus specified but coefficientType " + "is not integral"; + } + APInt coeffModValue = coefficientModulus.getValue(); + if (coeffModValue == 0) { + return emitError() << "coefficientModulus should not be 0"; + } + if (coeffModValue.slt(0)) { + return emitError() << "coefficientModulus should be positive"; + } + auto coeffModWidth = (coeffModValue - 1).getActiveBits(); + auto coeffWidth = coeffIntType.getWidth(); + if (coeffModWidth > coeffWidth) { + return emitError() << "coefficientModulus needs bit width of " + << coeffModWidth + << " but coefficientType can only contain " + << coeffWidth << " bits"; + } + } + return success(); +} + } // namespace polynomial } // namespace mlir diff --git a/mlir/test/Dialect/Polynomial/attributes.mlir b/mlir/test/Dialect/Polynomial/attributes.mlir index 4bdfd44fd4d15..cb3216900cb43 100644 --- a/mlir/test/Dialect/Polynomial/attributes.mlir +++ b/mlir/test/Dialect/Polynomial/attributes.mlir @@ -37,3 +37,37 @@ // expected-error@below {{failed to parse Polynomial_RingAttr parameter 'coefficientModulus' which is to be a `::mlir::IntegerAttr`}} // expected-error@below {{expected attribute value}} #ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{coefficientModulus specified but coefficientType is not integral}} +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{coefficientModulus should not be 0}} +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{coefficientModulus should be positive}} +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{coefficientModulus needs bit width of 33 but coefficientType can only contain 32 bits}} +#ring1 = #polynomial.ring + +// ----- + +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{coefficientModulus should be positive}} +#ring1 = #polynomial.ring + +// ----- + +// unfortunately, coefficientModulus of 64bit should be contained in larger type +#ring1 = #polynomial.ring