Skip to content

Commit 932bef2

Browse files
authored
Poly canonicalization (#91410)
Adds simple canonicalization rules to the polynomial dialect. Mainly to get the boilerplate incorporated before more substantial canonicalization patterns are added. --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent a4ad052 commit 932bef2

File tree

6 files changed

+138
-0
lines changed

6 files changed

+138
-0
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

+3
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
245245
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
246246
```
247247
}];
248+
let hasCanonicalizer = 1;
248249
}
249250

250251
def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
@@ -480,6 +481,7 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
480481
let arguments = (ins Polynomial_PolynomialType:$input);
481482
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
482483
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
484+
let hasCanonicalizer = 1;
483485
let hasVerifier = 1;
484486
}
485487

@@ -498,6 +500,7 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
498500
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
499501
let results = (outs Polynomial_PolynomialType:$output);
500502
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
503+
let hasCanonicalizer = 1;
501504
let hasVerifier = 1;
502505
}
503506

mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
set(LLVM_TARGET_DEFINITIONS PolynomialCanonicalization.td)
2+
mlir_tablegen(PolynomialCanonicalization.inc -gen-rewriters)
3+
add_public_tablegen_target(MLIRPolynomialCanonicalizationIncGen)
4+
15
add_mlir_dialect_library(MLIRPolynomialDialect
26
Polynomial.cpp
37
PolynomialAttributes.cpp
@@ -10,6 +14,7 @@ add_mlir_dialect_library(MLIRPolynomialDialect
1014
DEPENDS
1115
MLIRPolynomialIncGen
1216
MLIRPolynomialAttributesIncGen
17+
MLIRPolynomialCanonicalizationIncGen
1318
MLIRBuiltinAttributesIncGen
1419

1520
LINK_LIBS PUBLIC
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===- PolynomialCanonicalization.td - Polynomial patterns -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef POLYNOMIAL_CANONICALIZATION
10+
#define POLYNOMIAL_CANONICALIZATION
11+
12+
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
13+
include "mlir/Dialect/Arith/IR/ArithOps.td"
14+
include "mlir/IR/OpBase.td"
15+
include "mlir/IR/PatternBase.td"
16+
17+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
18+
// ring coefficient type.
19+
def getMinusOne
20+
: NativeCodeCall<
21+
"$_builder.getIntegerAttr("
22+
"cast<PolynomialType>($0.getType()).getRing().getCoefficientType(), -1)">;
23+
24+
def SubAsAdd : Pat<
25+
(Polynomial_SubOp $f, $g),
26+
(Polynomial_AddOp $f,
27+
(Polynomial_MulScalarOp $g,
28+
(Arith_ConstantOp (getMinusOne $g))))>;
29+
30+
def INTTAfterNTT : Pat<
31+
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
32+
(replaceWithValue $poly),
33+
[]
34+
>;
35+
36+
def NTTAfterINTT : Pat<
37+
(Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
38+
(replaceWithValue $tensor),
39+
[]
40+
>;
41+
42+
#endif // POLYNOMIAL_CANONICALIZATION

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
1112
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
1213
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
1314
#include "mlir/IR/Builders.h"
1415
#include "mlir/IR/BuiltinTypes.h"
1516
#include "mlir/IR/Dialect.h"
17+
#include "mlir/IR/PatternMatch.h"
1618
#include "mlir/Support/LogicalResult.h"
1719
#include "llvm/ADT/APInt.h"
1820

@@ -183,3 +185,26 @@ LogicalResult INTTOp::verify() {
183185
auto ring = getOutput().getType().getRing();
184186
return verifyNTTOp(this->getOperation(), ring, tensorType);
185187
}
188+
189+
//===----------------------------------------------------------------------===//
190+
// TableGen'd canonicalization patterns
191+
//===----------------------------------------------------------------------===//
192+
193+
namespace {
194+
#include "PolynomialCanonicalization.inc"
195+
} // namespace
196+
197+
void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
198+
MLIRContext *context) {
199+
results.add<SubAsAdd>(context);
200+
}
201+
202+
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
203+
MLIRContext *context) {
204+
results.add<NTTAfterINTT>(context);
205+
}
206+
207+
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
208+
MLIRContext *context) {
209+
results.add<INTTAfterNTT>(context);
210+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: mlir-opt -canonicalize %s | FileCheck %s
2+
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
3+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
4+
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
5+
!tensor_ty = tensor<8xi32, #ntt_ring>
6+
7+
// CHECK-LABEL: @test_canonicalize_intt_after_ntt
8+
// CHECK: (%[[P:.*]]: [[T:.*]]) -> [[T]]
9+
func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty {
10+
// CHECK-NOT: polynomial.ntt
11+
// CHECK-NOT: polynomial.intt
12+
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
13+
%t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
14+
%p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
15+
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
16+
// CHECK: return %[[RESULT]] : [[T]]
17+
return %p2 : !ntt_poly_ty
18+
}
19+
20+
// CHECK-LABEL: @test_canonicalize_ntt_after_intt
21+
// CHECK: (%[[X:.*]]: [[T:.*]]) -> [[T]]
22+
func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
23+
// CHECK-NOT: polynomial.intt
24+
// CHECK-NOT: polynomial.ntt
25+
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
26+
%p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
27+
%t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
28+
%t2 = arith.addi %t1, %t1 : !tensor_ty
29+
// CHECK: return %[[RESULT]] : [[T]]
30+
return %t2 : !tensor_ty
31+
}
32+
33+
#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
34+
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
35+
!sub_ty = !polynomial.polynomial<ring=#ring>
36+
37+
// CHECK-LABEL: test_canonicalize_sub
38+
// CHECK-SAME: (%[[p0:.*]]: [[T:.*]], %[[p1:.*]]: [[T]]) -> [[T]] {
39+
func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty {
40+
%0 = polynomial.sub %poly0, %poly1 : !sub_ty
41+
// CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
42+
// CHECK: %[[p1neg:.+]] = polynomial.mul_scalar %[[p1]], %[[minus_one]]
43+
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
44+
return %0 : !sub_ty
45+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

+18
Original file line numberDiff line numberDiff line change
@@ -6728,6 +6728,7 @@ cc_library(
67286728
":IR",
67296729
":InferTypeOpInterface",
67306730
":PolynomialAttributesIncGen",
6731+
":PolynomialCanonicalizationIncGen",
67316732
":PolynomialIncGen",
67326733
":Support",
67336734
"//llvm:Support",
@@ -6818,6 +6819,23 @@ gentbl_cc_library(
68186819
deps = [":PolynomialTdFiles"],
68196820
)
68206821

6822+
gentbl_cc_library(
6823+
name = "PolynomialCanonicalizationIncGen",
6824+
strip_include_prefix = "include/mlir/Dialect/Polynomial/IR",
6825+
tbl_outs = [
6826+
(
6827+
["-gen-rewriters"],
6828+
"include/mlir/Dialect/Polynomial/IR/PolynomialCanonicalization.inc",
6829+
),
6830+
],
6831+
tblgen = ":mlir-tblgen",
6832+
td_file = "lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td",
6833+
deps = [
6834+
":ArithOpsTdFiles",
6835+
":PolynomialTdFiles",
6836+
],
6837+
)
6838+
68216839
td_library(
68226840
name = "SPIRVOpsTdFiles",
68236841
srcs = glob(["include/mlir/Dialect/SPIRV/IR/*.td"]),

0 commit comments

Comments
 (0)