|
| 1 | +// RUN: mlir-opt -canonicalize %s | FileCheck %s |
| 2 | +#ntt_poly = #polynomial.int_polynomial<-1 + x**8> |
| 3 | +#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31> |
| 4 | +!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring> |
| 5 | +!tensor_ty = tensor<8xi32, #ntt_ring> |
| 6 | + |
| 7 | +// CHECK-LABEL: @test_canonicalize_intt_after_ntt |
| 8 | +// CHECK: (%[[P:.*]]: [[T:.*]]) -> [[T]] |
| 9 | +func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty { |
| 10 | + // CHECK-NOT: polynomial.ntt |
| 11 | + // CHECK-NOT: polynomial.intt |
| 12 | + // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]] |
| 13 | + %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty |
| 14 | + %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty |
| 15 | + %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty |
| 16 | + // CHECK: return %[[RESULT]] : [[T]] |
| 17 | + return %p2 : !ntt_poly_ty |
| 18 | +} |
| 19 | + |
| 20 | +// CHECK-LABEL: @test_canonicalize_ntt_after_intt |
| 21 | +// CHECK: (%[[X:.*]]: [[T:.*]]) -> [[T]] |
| 22 | +func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty { |
| 23 | + // CHECK-NOT: polynomial.intt |
| 24 | + // CHECK-NOT: polynomial.ntt |
| 25 | + // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]] |
| 26 | + %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty |
| 27 | + %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty |
| 28 | + %t2 = arith.addi %t1, %t1 : !tensor_ty |
| 29 | + // CHECK: return %[[RESULT]] : [[T]] |
| 30 | + return %t2 : !tensor_ty |
| 31 | +} |
| 32 | + |
| 33 | +#cycl_2048 = #polynomial.int_polynomial<1 + x**1024> |
| 34 | +#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048> |
| 35 | +!sub_ty = !polynomial.polynomial<ring=#ring> |
| 36 | + |
| 37 | +// CHECK-LABEL: test_canonicalize_sub |
| 38 | +// CHECK-SAME: (%[[p0:.*]]: [[T:.*]], %[[p1:.*]]: [[T]]) -> [[T]] { |
| 39 | +func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty { |
| 40 | + %0 = polynomial.sub %poly0, %poly1 : !sub_ty |
| 41 | + // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32 |
| 42 | + // CHECK: %[[p1neg:.+]] = polynomial.mul_scalar %[[p1]], %[[minus_one]] |
| 43 | + // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]] |
| 44 | + return %0 : !sub_ty |
| 45 | +} |
0 commit comments