Skip to content

Commit 1f46729

Browse files
authored
[polynomial] Move primitive root attribute to ntt/intt ops. (#93227)
Better design to put semantics on the ops, and in this case the ntt/intt op can lower in multiple ways depending on the polynomial ring modulus (it can need an nth root of unity for cyclic polymul -> ntt, or a 2nth root for negacyclic polymul -> ntt) --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent b5db2e1 commit 1f46729

File tree

7 files changed

+127
-97
lines changed

7 files changed

+127
-97
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

+11-4
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,6 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
277277
Polynomial_TypedIntPolynomialAttr
278278
]>;
279279

280-
// Not deriving from Polynomial_Op due to need for custom assembly format
281280
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
282281
[Pure, InferTypeOpAdaptor]> {
283282
let summary = "Define a constant polynomial via an attribute.";
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
312311

313312
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
314313

315-
The choice of primitive root is determined by subsequent lowerings.
314+
The choice of primitive root may be optionally specified.
316315
}];
317-
let arguments = (ins Polynomial_PolynomialType:$input);
316+
let arguments = (ins
317+
Polynomial_PolynomialType:$input,
318+
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
319+
);
318320
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
319321
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
320322
let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
332334
output polynomial at powers of a primitive `n`-th root of unity (see
333335
`polynomial.ntt`). The ring of the polynomial is taken from the required
334336
encoding attribute of the tensor.
337+
338+
The choice of primitive root may be optionally specified.
335339
}];
336-
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
340+
let arguments = (
341+
ins RankedTensorOf<[AnyInteger]>:$input,
342+
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
343+
);
337344
let results = (outs Polynomial_PolynomialType:$output);
338345
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
339346
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td

+27-6
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
166166
let parameters = (ins
167167
"Type": $coefficientType,
168168
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
169-
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
170-
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
169+
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
171170
);
172171
let assemblyFormat = "`<` struct(params) `>`";
173172
let builders = [
174173
AttrBuilderWithInferredContext<
175174
(ins "::mlir::Type":$coefficientTy,
176175
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
177-
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
178-
CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
176+
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
179177
return $_get(
180178
coefficientTy.getContext(),
181179
coefficientTy,
182180
coefficientModulusAttr,
183-
polynomialModulusAttr,
184-
primitiveRootAttr);
181+
polynomialModulusAttr);
185182
}]>,
186183
];
187184
}
188185

186+
def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
187+
let summary = "an attribute containing an integer and its degree as a root of unity";
188+
let description = [{
189+
A primitive root attribute stores an integer root `value` and an integer
190+
`degree`, corresponding to a primitive root of unity of the given degree in
191+
an unspecified ring.
192+
193+
This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
194+
to specify the root of unity used in lowering the transform.
195+
196+
Example:
197+
198+
```mlir
199+
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
200+
```
201+
}];
202+
let parameters = (ins
203+
"::mlir::IntegerAttr":$value,
204+
"::mlir::IntegerAttr":$degree
205+
);
206+
let assemblyFormat = "`<` struct(params) `>`";
207+
}
208+
209+
189210
#endif // POLYNOMIAL_ATTRIBUTES

mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td

+22-20
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
1717

1818
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
1919

20+
def Equal : Constraint<CPred<"$0 == $1">>;
21+
2022
// Get a -1 integer attribute of the same type as the polynomial SSA value's
2123
// ring coefficient type.
2224
def getMinusOne
@@ -31,51 +33,51 @@ def SubAsAdd : Pat<
3133
(Arith_ConstantOp (getMinusOne $g))))>;
3234

3335
def INTTAfterNTT : Pat<
34-
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
36+
(Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
3537
(replaceWithValue $poly),
36-
[]
38+
[(Equal $r1, $r2)]
3739
>;
3840

3941
def NTTAfterINTT : Pat<
40-
(Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
42+
(Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
4143
(replaceWithValue $tensor),
42-
[]
44+
[(Equal $r1, $r2)]
4345
>;
4446

4547
// NTTs are expensive, and addition in coefficient or NTT domain should be
4648
// equivalently expensive, so reducing the number of NTTs is optimal.
4749
// ntt(a) + ntt(b) -> ntt(a + b)
4850
def NTTOfAdd : Pat<
4951
(Arith_AddIOp
50-
(Polynomial_NTTOp $p1),
51-
(Polynomial_NTTOp $p2),
52+
(Polynomial_NTTOp $p1, $r1),
53+
(Polynomial_NTTOp $p2, $r2),
5254
$overflow),
53-
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
54-
[]
55+
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
56+
[(Equal $r1, $r2)]
5557
>;
5658
// intt(a) + intt(b) -> intt(a + b)
5759
def INTTOfAdd : Pat<
5860
(Polynomial_AddOp
59-
(Polynomial_INTTOp $t1),
60-
(Polynomial_INTTOp $t2)),
61-
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
62-
[]
61+
(Polynomial_INTTOp $t1, $r1),
62+
(Polynomial_INTTOp $t2, $r2)),
63+
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
64+
[(Equal $r1, $r2)]
6365
>;
6466
// repeated for sub
6567
def NTTOfSub : Pat<
6668
(Arith_SubIOp
67-
(Polynomial_NTTOp $p1),
68-
(Polynomial_NTTOp $p2),
69+
(Polynomial_NTTOp $p1, $r1),
70+
(Polynomial_NTTOp $p2, $r2),
6971
$overflow),
70-
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
71-
[]
72+
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
73+
[(Equal $r1, $r2)]
7274
>;
7375
def INTTOfSub : Pat<
7476
(Polynomial_SubOp
75-
(Polynomial_INTTOp $t1),
76-
(Polynomial_INTTOp $t2)),
77-
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
78-
[]
77+
(Polynomial_INTTOp $t1, $r1),
78+
(Polynomial_INTTOp $t2, $r2)),
79+
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
80+
[(Equal $r1, $r2)]
7981
>;
8082

8183
#endif // POLYNOMIAL_CANONICALIZATION

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

+20-21
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
108108
}
109109

110110
/// Test if a value is a primitive nth root of unity modulo cmod.
111-
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
111+
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
112112
const APInt &cmod) {
113113
// Root bitwidth may be 1 less then cmod.
114114
APInt r = APInt(root).zext(cmod.getBitWidth());
115115
assert(r.ule(cmod) && "root must be less than cmod");
116+
unsigned upperBound = n.getZExtValue();
116117

117118
APInt a = r;
118-
for (size_t k = 1; k < n; k++) {
119+
for (size_t k = 1; k < upperBound; k++) {
119120
if (a.isOne())
120121
return false;
121122
a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
126127
/// Verify that the types involved in an NTT or INTT operation are
127128
/// compatible.
128129
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
129-
RankedTensorType tensorType) {
130+
RankedTensorType tensorType,
131+
std::optional<PrimitiveRootAttr> root) {
130132
Attribute encoding = tensorType.getEncoding();
131133
if (!encoding) {
132134
return op->emitOpError()
@@ -157,33 +159,30 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
157159
return diag;
158160
}
159161

160-
if (!ring.getPrimitiveRoot()) {
161-
return op->emitOpError()
162-
<< "ring type " << ring << " does not provide a primitive root "
163-
<< "of unity, which is required to express an NTT";
164-
}
165-
166-
if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
167-
ring.getCoefficientModulus().getValue())) {
168-
return op->emitOpError()
169-
<< "ring type " << ring << " has a primitiveRoot attribute '"
170-
<< ring.getPrimitiveRoot()
171-
<< "' that is not a primitive root of the coefficient ring";
162+
if (root.has_value()) {
163+
APInt rootValue = root.value().getValue().getValue();
164+
APInt rootDegree = root.value().getDegree().getValue();
165+
APInt cmod = ring.getCoefficientModulus().getValue();
166+
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
167+
return op->emitOpError()
168+
<< "provided root " << rootValue.getZExtValue()
169+
<< " is not a primitive root "
170+
<< "of unity mod " << cmod.getZExtValue()
171+
<< ", with the specified degree " << rootDegree.getZExtValue();
172+
}
172173
}
173174

174175
return success();
175176
}
176177

177178
LogicalResult NTTOp::verify() {
178-
auto ring = getInput().getType().getRing();
179-
auto tensorType = getOutput().getType();
180-
return verifyNTTOp(this->getOperation(), ring, tensorType);
179+
return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
180+
getOutput().getType(), getRoot());
181181
}
182182

183183
LogicalResult INTTOp::verify() {
184-
auto tensorType = getInput().getType();
185-
auto ring = getOutput().getType().getRing();
186-
return verifyNTTOp(this->getOperation(), ring, tensorType);
184+
return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
185+
getInput().getType(), getRoot());
187186
}
188187

189188
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {

mlir/test/Dialect/Polynomial/canonicalization.mlir

+32-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt -canonicalize %s | FileCheck %s
22
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
3-
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
3+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
4+
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
45
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
56
!tensor_ty = tensor<8xi32, #ntt_ring>
67

@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
1011
// CHECK-NOT: polynomial.ntt
1112
// CHECK-NOT: polynomial.intt
1213
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
13-
%t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
14-
%p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
14+
%t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
15+
%p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
1516
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
1617
// CHECK: return %[[RESULT]] : [[T]]
1718
return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
2324
// CHECK-NOT: polynomial.intt
2425
// CHECK-NOT: polynomial.ntt
2526
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
26-
%p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
27-
%t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
27+
%p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
28+
%t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
2829
%t2 = arith.addi %t1, %t1 : !tensor_ty
2930
// CHECK: return %[[RESULT]] : [[T]]
3031
return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
5152
func.func @test_canonicalize_fold_add_through_ntt(
5253
%poly0 : !ntt_poly_ty,
5354
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
54-
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
55-
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
55+
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
56+
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
5657
%a_plus_b = arith.addi %0, %1 : !tensor_ty
57-
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
58+
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
5859
return %out : !ntt_poly_ty
5960
}
6061

@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
6566
func.func @test_canonicalize_fold_add_through_intt(
6667
%tensor0 : !tensor_ty,
6768
%tensor1 : !tensor_ty) -> !tensor_ty {
68-
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
69-
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
69+
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
70+
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
7071
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
71-
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
72+
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
7273
return %out : !tensor_ty
7374
}
7475

@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
8081
func.func @test_canonicalize_fold_sub_through_ntt(
8182
%poly0 : !ntt_poly_ty,
8283
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
83-
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
84-
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
84+
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
85+
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
8586
%a_plus_b = arith.subi %0, %1 : !tensor_ty
86-
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
87+
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
8788
return %out : !ntt_poly_ty
8889
}
8990

@@ -94,9 +95,23 @@ func.func @test_canonicalize_fold_sub_through_ntt(
9495
func.func @test_canonicalize_fold_sub_through_intt(
9596
%tensor0 : !tensor_ty,
9697
%tensor1 : !tensor_ty) -> !tensor_ty {
97-
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
98-
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
98+
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
99+
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
99100
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
100-
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
101+
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
101102
return %out : !tensor_ty
102103
}
104+
105+
106+
// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
107+
// CHECK: arith.addi
108+
func.func @test_canonicalize_do_not_fold_different_roots(
109+
%poly0 : !ntt_poly_ty,
110+
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
111+
%0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
112+
%1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
113+
%a_plus_b = arith.addi %0, %1 : !tensor_ty
114+
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
115+
return %out : !ntt_poly_ty
116+
}
117+

mlir/test/Dialect/Polynomial/ops.mlir

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
1212

1313
#ideal = #polynomial.int_polynomial<-1 + x**1024>
14-
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
14+
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
1515
!poly_ty = !polynomial.polynomial<ring=#ring>
1616

1717
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
18-
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
18+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
1919
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
2020

2121
module {
@@ -91,12 +91,12 @@ module {
9191
}
9292

9393
func.func @test_ntt(%0 : !ntt_poly_ty) {
94-
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
94+
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
9595
return
9696
}
9797

9898
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
99-
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
99+
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
100100
return
101101
}
102102
}

0 commit comments

Comments
 (0)