1
1
// RUN: mlir-opt -canonicalize %s | FileCheck %s
2
2
#ntt_poly = #polynomial.int_polynomial <-1 + x **8 >
3
- #ntt_ring = #polynomial.ring <coefficientType =i32 , coefficientModulus =256 , polynomialModulus =#ntt_poly , primitiveRoot =31 >
3
+ #ntt_ring = #polynomial.ring <coefficientType =i32 , coefficientModulus =256 , polynomialModulus =#ntt_poly >
4
+ #root = #polynomial.primitive_root <value =31 :i32 , degree =8 :index >
4
5
!ntt_poly_ty = !polynomial.polynomial <ring =#ntt_ring >
5
6
!tensor_ty = tensor <8 xi32 , #ntt_ring >
6
7
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
10
11
// CHECK-NOT: polynomial.ntt
11
12
// CHECK-NOT: polynomial.intt
12
13
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
13
- %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
14
- %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
14
+ %t0 = polynomial.ntt %p0 { root = #root } : !ntt_poly_ty -> !tensor_ty
15
+ %p1 = polynomial.intt %t0 { root = #root } : !tensor_ty -> !ntt_poly_ty
15
16
%p2 = polynomial.add %p1 , %p1 : !ntt_poly_ty
16
17
// CHECK: return %[[RESULT]] : [[T]]
17
18
return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
23
24
// CHECK-NOT: polynomial.intt
24
25
// CHECK-NOT: polynomial.ntt
25
26
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
26
- %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
27
- %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
27
+ %p0 = polynomial.intt %t0 { root = #root } : !tensor_ty -> !ntt_poly_ty
28
+ %t1 = polynomial.ntt %p0 { root = #root } : !ntt_poly_ty -> !tensor_ty
28
29
%t2 = arith.addi %t1 , %t1 : !tensor_ty
29
30
// CHECK: return %[[RESULT]] : [[T]]
30
31
return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
51
52
func.func @test_canonicalize_fold_add_through_ntt (
52
53
%poly0 : !ntt_poly_ty ,
53
54
%poly1 : !ntt_poly_ty ) -> !ntt_poly_ty {
54
- %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
55
- %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
55
+ %0 = polynomial.ntt %poly0 { root = #root } : !ntt_poly_ty -> !tensor_ty
56
+ %1 = polynomial.ntt %poly1 { root = #root } : !ntt_poly_ty -> !tensor_ty
56
57
%a_plus_b = arith.addi %0 , %1 : !tensor_ty
57
- %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
58
+ %out = polynomial.intt %a_plus_b { root = #root } : !tensor_ty -> !ntt_poly_ty
58
59
return %out : !ntt_poly_ty
59
60
}
60
61
@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
65
66
func.func @test_canonicalize_fold_add_through_intt (
66
67
%tensor0 : !tensor_ty ,
67
68
%tensor1 : !tensor_ty ) -> !tensor_ty {
68
- %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
69
- %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
69
+ %0 = polynomial.intt %tensor0 { root = #root } : !tensor_ty -> !ntt_poly_ty
70
+ %1 = polynomial.intt %tensor1 { root = #root } : !tensor_ty -> !ntt_poly_ty
70
71
%a_plus_b = polynomial.add %0 , %1 : !ntt_poly_ty
71
- %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
72
+ %out = polynomial.ntt %a_plus_b { root = #root } : !ntt_poly_ty -> !tensor_ty
72
73
return %out : !tensor_ty
73
74
}
74
75
@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
80
81
func.func @test_canonicalize_fold_sub_through_ntt (
81
82
%poly0 : !ntt_poly_ty ,
82
83
%poly1 : !ntt_poly_ty ) -> !ntt_poly_ty {
83
- %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
84
- %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
84
+ %0 = polynomial.ntt %poly0 { root = #root } : !ntt_poly_ty -> !tensor_ty
85
+ %1 = polynomial.ntt %poly1 { root = #root } : !ntt_poly_ty -> !tensor_ty
85
86
%a_plus_b = arith.subi %0 , %1 : !tensor_ty
86
- %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
87
+ %out = polynomial.intt %a_plus_b { root = #root } : !tensor_ty -> !ntt_poly_ty
87
88
return %out : !ntt_poly_ty
88
89
}
89
90
@@ -94,9 +95,23 @@ func.func @test_canonicalize_fold_sub_through_ntt(
94
95
func.func @test_canonicalize_fold_sub_through_intt (
95
96
%tensor0 : !tensor_ty ,
96
97
%tensor1 : !tensor_ty ) -> !tensor_ty {
97
- %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
98
- %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
98
+ %0 = polynomial.intt %tensor0 { root = #root } : !tensor_ty -> !ntt_poly_ty
99
+ %1 = polynomial.intt %tensor1 { root = #root } : !tensor_ty -> !ntt_poly_ty
99
100
%a_plus_b = polynomial.sub %0 , %1 : !ntt_poly_ty
100
- %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
101
+ %out = polynomial.ntt %a_plus_b { root = #root } : !ntt_poly_ty -> !tensor_ty
101
102
return %out : !tensor_ty
102
103
}
104
+
105
+
106
+ // CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
107
+ // CHECK: arith.addi
108
+ func.func @test_canonicalize_do_not_fold_different_roots (
109
+ %poly0 : !ntt_poly_ty ,
110
+ %poly1 : !ntt_poly_ty ) -> !ntt_poly_ty {
111
+ %0 = polynomial.ntt %poly0 {root =#polynomial.primitive_root <value =31 :i32 , degree =8 :index >} : !ntt_poly_ty -> !tensor_ty
112
+ %1 = polynomial.ntt %poly1 {root =#polynomial.primitive_root <value =33 :i32 , degree =8 :index >} : !ntt_poly_ty -> !tensor_ty
113
+ %a_plus_b = arith.addi %0 , %1 : !tensor_ty
114
+ %out = polynomial.intt %a_plus_b {root =#root } : !tensor_ty -> !ntt_poly_ty
115
+ return %out : !ntt_poly_ty
116
+ }
117
+
0 commit comments