Skip to content

Commit be4ffb6

Browse files
authored
[Datapath] Custom Partial Product Lowering for computing the Square of the input (#9010)
* Add squarer partial product array reduction * Add tests for square partial products * Address comments
1 parent 6093d7b commit be4ffb6

File tree

3 files changed

+138
-2
lines changed

3 files changed

+138
-2
lines changed

integration_test/circt-synth/datapath-lowering-lec.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,24 @@ hw.module @partial_product_zext(in %a : i3, in %b : i3, out sum : i6) {
2929
hw.output %3 : i6
3030
}
3131

32+
// RUN: circt-lec %t.mlir %s -c1=partial_product_square -c2=partial_product_square --shared-libs=%libz3 | FileCheck %s --check-prefix=SQR4
33+
// SQR4: c1 == c2
34+
// Squarer case for equivalence checking: both operands of the partial
// product are %a, and the four i4 rows are summed into the 4-bit output.
hw.module @partial_product_square(in %a : i4, out sum : i4) {
35+
%0:4 = datapath.partial_product %a, %a : (i4, i4) -> (i4, i4, i4, i4)
36+
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3 : i4
37+
hw.output %1 : i4
38+
}
39+
40+
// RUN: circt-lec %t.mlir %s -c1=partial_product_square_zext -c2=partial_product_square_zext --shared-libs=%libz3 | FileCheck %s --check-prefix=SQR3_ZEXT
41+
// SQR3_ZEXT: c1 == c2
42+
// Squarer case with a zero-extended operand: %a (i3) is widened to i6 by
// concatenating three zero bits on top, the widened value feeds both
// partial-product operands, and only three i6 rows are requested and summed.
hw.module @partial_product_square_zext(in %a : i3, out sum : i6) {
43+
%c0_i3 = hw.constant 0 : i3
44+
%0 = comb.concat %c0_i3, %a : i3, i3
45+
%1:3 = datapath.partial_product %0, %0 : (i6, i6) -> (i6, i6, i6)
46+
%2 = comb.add %1#0, %1#1, %1#2 : i6
47+
hw.output %2 : i6
48+
}
49+
3250
// RUN: circt-lec %t.mlir %s -c1=partial_product_sext -c2=partial_product_sext --shared-libs=%libz3 | FileCheck %s --check-prefix=AND3_SEXT
3351
// AND3_SEXT: c1 == c2
3452
hw.module @partial_product_sext(in %a : i3, in %b : i3, out sum : i6) {

lib/Conversion/DatapathToComb/DatapathToComb.cpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1919
#include "llvm/Support/Debug.h"
2020
#include "llvm/Support/KnownBits.h"
21+
#include <algorithm>
2122

2223
#define DEBUG_TYPE "datapath-to-comb"
2324

@@ -122,6 +123,21 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
122123
return success();
123124
}
124125

126+
// The partial product array of a square (a * a) can be reduced to an upper
// triangular array, because the terms a[i]a[j] and a[j]a[i] are identical.
127+
// For example: AND array for a 4-bit squarer:
128+
// 0 0 0 a0a3 a0a2 a0a1 a0a0
129+
// 0 0 a1a3 a1a2 a1a1 a1a0 0
130+
// 0 a2a3 a2a2 a2a1 a2a0 0 0
131+
// a3a3 a3a2 a3a1 a3a0 0 0 0
132+
//
133+
// Can be reduced to:
134+
// 0 0 a0a3 a0a2 a0a1 0 a0
135+
// 0 a1a3 a1a2 0 a1 0 0
136+
// a2a3 0 a2 0 0 0 0
137+
// a3 0 0 0 0 0 0
138+
if (a == b)
139+
return lowerSqrAndArray(rewriter, a, op, width);
140+
125141
// Use result rows as a heuristic to guide partial product
126142
// implementation
127143
if (op.getNumResults() > 16 || forceBooth)
@@ -166,6 +182,70 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
166182
return success();
167183
}
168184

185+
/// Lowers a partial product whose two operands are the same value `a` (a
/// squarer) and replaces `op` with `op.getNumResults()` rows, each `width`
/// bits wide.
///
/// Reading row `i` from the least significant bit upwards, it consists of:
///   - 2*i zero bits,
///   - a[i],
///   - a single zero bit,
///   - a[i] & a[i+1], a[i] & a[i+2], ...
/// cut off once `width` bits have been produced. Rows for which
/// 2*i >= width are emitted as an all-zero constant. AND terms whose second
/// index is not below the result count are emitted as a constant zero bit
/// (the result count is taken to bound the non-zero bits of the input).
///
/// Always returns success().
static LogicalResult lowerSqrAndArray(PatternRewriter &rewriter, Value a,
                                      PartialProductOp op, unsigned width) {
  Location loc = op.getLoc();
  SmallVector<Value> bitsOfA = extractBits(rewriter, a);

  const unsigned numRows = op.getNumResults();
  assert(numRows <= width &&
         "Cannot return more results than the operator width");

  SmallVector<Value> rows;
  rows.reserve(width);

  // One shared single-bit zero serves every padding position.
  Value falseBit = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

  for (unsigned i = 0; i != numRows; ++i) {
    // Column that holds the diagonal term a[i] (a[i] & a[i] == a[i]).
    const unsigned diagCol = 2 * i;

    // The diagonal term already lies outside the result: the row is zero.
    if (diagCol >= width) {
      rows.push_back(hw::ConstantOp::create(rewriter, loc, APInt(width, 0)));
      continue;
    }

    // Collect the row starting from the least significant end.
    SmallVector<Value> lsbFirst;
    lsbFirst.reserve(width);

    // Low zero padding that shifts the row into place.
    if (i != 0)
      lsbFirst.push_back(
          hw::ConstantOp::create(rewriter, loc, APInt(diagCol, 0)));
    lsbFirst.push_back(bitsOfA[i]);

    // The column directly above the diagonal term is always zero.
    if (diagCol + 1 < width)
      lsbFirst.push_back(falseBit);

    // Remaining columns hold the off-diagonal terms a[i] & a[j], j > i,
    // until the row reaches the full width.
    unsigned j = i + 1;
    for (unsigned col = diagCol + 2; col < width; ++col, ++j) {
      if (j < numRows)
        lsbFirst.push_back(
            rewriter.createOrFold<comb::AndOp>(loc, bitsOfA[i], bitsOfA[j]));
      else
        lsbFirst.push_back(falseBit);
    }

    // The concat operands are supplied most significant first.
    SmallVector<Value> msbFirst(lsbFirst.rbegin(), lsbFirst.rend());
    rows.push_back(comb::ConcatOp::create(rewriter, loc, msbFirst));
  }

  rewriter.replaceOp(op, rows);
  return success();
}
248+
169249
static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a,
170250
Value b, PartialProductOp op,
171251
unsigned width) {
@@ -370,8 +450,8 @@ struct DatapathPosPartialProductOpConversion
370450
unsigned width) {
371451

372452
Location loc = op.getLoc();
373-
// Encode (a+b) by implementing a half-adder - then note the following fact
374-
// carry[i] & save[i] == false
453+
// Encode (a+b) by implementing a half-adder - then note the following
454+
// fact carry[i] & save[i] == false
375455
auto carry = rewriter.createOrFold<comb::AndOp>(loc, a, b);
376456
auto save = rewriter.createOrFold<comb::XorOp>(loc, a, b);
377457

test/Conversion/DatapathToComb/datapath-to-comb.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,44 @@ hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, o
5656
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
5757
}
5858

59+
// CHECK-LABEL: @partial_product_square
60+
// Squarer lowering on a 3-bit input with three rows requested. The expected
// output is: pp0 = {a0 & a1, 0, a0}, pp1 = {a1, 00}, and pp2 is the all-zero
// constant.
hw.module @partial_product_square(in %a : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
61+
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3
62+
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
63+
// CHECK-NEXT: %false = hw.constant false
64+
// CHECK-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i3) -> i1
65+
// CHECK-NEXT: %[[A1:.+]] = comb.extract %a from 1 : (i3) -> i1
66+
// CHECK-NEXT: %[[A01:.+]] = comb.and %[[A0]], %[[A1]] : i1
67+
// CHECK-NEXT: %[[PP0:.+]] = comb.concat %[[A01]], %false, %[[A0]] : i1, i1, i1
68+
// CHECK-NEXT: %[[PP1:.+]] = comb.concat %[[A1]], %c0_i2 : i1, i2
69+
// CHECK-NEXT: hw.output %[[PP0]], %[[PP1]], %c0_i3 : i3, i3, i3
70+
%0:3 = datapath.partial_product %a, %a : (i3, i3) -> (i3, i3, i3)
71+
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
72+
}
73+
74+
// CHECK-LABEL: @partial_product_square_zext
75+
// Squarer lowering on a 3-bit input zero-extended to 6 bits, with three rows
// requested. Only bits 0..2 of the widened value are extracted and combined;
// every other position in the expected rows is a constant zero.
hw.module @partial_product_square_zext(in %a : i3, out pp0 : i6, out pp1 : i6, out pp2 : i6) {
76+
// CHECK-NEXT: %c0_i4 = hw.constant 0 : i4
77+
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
78+
// CHECK-NEXT: %false = hw.constant false
79+
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3
80+
// CHECK-NEXT: %[[AEXT:.+]] = comb.concat %c0_i3, %a : i3, i3
81+
// CHECK-NEXT: %[[A0:.+]] = comb.extract %[[AEXT]] from 0 : (i6) -> i1
82+
// CHECK-NEXT: %[[A1:.+]] = comb.extract %[[AEXT]] from 1 : (i6) -> i1
83+
// CHECK-NEXT: %[[A2:.+]] = comb.extract %[[AEXT]] from 2 : (i6) -> i1
84+
// CHECK-NEXT: %[[A01:.+]] = comb.and %[[A0]], %[[A1]] : i1
85+
// CHECK-NEXT: %[[A02:.+]] = comb.and %[[A0]], %[[A2]] : i1
86+
// CHECK-NEXT: %[[PP0:.+]] = comb.concat %false, %false, %[[A02]], %[[A01]], %false, %[[A0]] : i1, i1, i1, i1, i1, i1
87+
// CHECK-NEXT: %[[A12:.+]] = comb.and %[[A1]], %[[A2]] : i1
88+
// CHECK-NEXT: %[[PP1:.+]] = comb.concat %false, %[[A12]], %false, %[[A1]], %c0_i2 : i1, i1, i1, i1, i2
89+
// CHECK-NEXT: %[[PP2:.+]] = comb.concat %false, %[[A2]], %c0_i4 : i1, i1, i4
90+
// CHECK-NEXT: hw.output %[[PP0]], %[[PP1]], %[[PP2]] : i6, i6, i6
91+
%c0_i3 = hw.constant 0 : i3
92+
%0 = comb.concat %c0_i3, %a : i3, i3
93+
%1:3 = datapath.partial_product %0, %0 : (i6, i6) -> (i6, i6, i6)
94+
hw.output %1#0, %1#1, %1#2 : i6, i6, i6
95+
}
96+
5997
// CHECK-LABEL: @partial_product_booth
6098
// FORCE-BOOTH-LABEL: @partial_product_booth
6199
// Constants

0 commit comments

Comments
 (0)