|
18 | 18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | 19 | #include "llvm/Support/Debug.h" |
20 | 20 | #include "llvm/Support/KnownBits.h" |
| 21 | +#include <algorithm> |
21 | 22 |
|
22 | 23 | #define DEBUG_TYPE "datapath-to-comb" |
23 | 24 |
|
@@ -122,6 +123,21 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> { |
122 | 123 | return success(); |
123 | 124 | } |
124 | 125 |
|
| 126 | + // Square partial product array can be reduced to upper triangular array. |
| 127 | + // For example: AND array for a 4-bit squarer: |
| 128 | + // 0 0 0 a0a3 a0a2 a0a1 a0a0 |
| 129 | + // 0 0 a1a3 a1a2 a1a1 a1a0 0 |
| 130 | + // 0 a2a3 a2a2 a2a1 a2a0 0 0 |
| 131 | + // a3a3 a3a2 a3a1 a3a0 0 0 0 |
| 132 | + // |
| 133 | + // Can be reduced to: |
| 134 | + // 0 0 a0a3 a0a2 a0a1 0 a0 |
| 135 | + // 0 a1a3 a1a2 0 a1 0 0 |
| 136 | + // a2a3 0 a2 0 0 0 0 |
| 137 | + // a3 0 0 0 0 0 0 |
| 138 | + if (a == b) |
| 139 | + return lowerSqrAndArray(rewriter, a, op, width); |
| 140 | + |
125 | 141 | // Use result rows as a heuristic to guide partial product |
126 | 142 | // implementation |
127 | 143 | if (op.getNumResults() > 16 || forceBooth) |
@@ -166,6 +182,70 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> { |
166 | 182 | return success(); |
167 | 183 | } |
168 | 184 |
|
| 185 | + static LogicalResult lowerSqrAndArray(PatternRewriter &rewriter, Value a, |
| 186 | + PartialProductOp op, unsigned width) { |
| 187 | + |
| 188 | + Location loc = op.getLoc(); |
| 189 | + SmallVector<Value> aBits = extractBits(rewriter, a); |
| 190 | + |
| 191 | + SmallVector<Value> partialProducts; |
| 192 | + partialProducts.reserve(width); |
| 193 | + // AND Array Construction - reducing to upper triangle: |
| 194 | + // partialProducts[i] = ({a[i],..., a[i]} & a) << i |
| 195 | + // optimised to: {a[i] & a[n-1], ..., a[i] & a[i+1], 0, a[i], 0, ..., 0} |
| 196 | + assert(op.getNumResults() <= width && |
| 197 | + "Cannot return more results than the operator width"); |
| 198 | + auto zeroFalse = hw::ConstantOp::create(rewriter, loc, APInt(1, 0)); |
| 199 | + for (unsigned i = 0; i < op.getNumResults(); ++i) { |
| 200 | + SmallVector<Value> row; |
| 201 | + row.reserve(width); |
| 202 | + |
| 203 | + if (2 * i >= width) { |
| 204 | + // Pad the remaining rows with zeros |
| 205 | + auto zeroWidth = hw::ConstantOp::create(rewriter, loc, APInt(width, 0)); |
| 206 | + partialProducts.push_back(zeroWidth); |
| 207 | + continue; |
| 208 | + } |
| 209 | + |
| 210 | + if (i > 0) { |
| 211 | + auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(2 * i, 0)); |
| 212 | + row.push_back(shiftBy); |
| 213 | + } |
| 214 | + row.push_back(aBits[i]); |
| 215 | + |
| 216 | + // Track width of constructed row |
| 217 | + unsigned rowWidth = 2 * i + 1; |
| 218 | + if (rowWidth < width) { |
| 219 | + row.push_back(zeroFalse); |
| 220 | + ++rowWidth; |
| 221 | + } |
| 222 | + |
| 223 | + for (unsigned j = i + 1; j < width; ++j) { |
| 224 | + // Stop when we reach the required width |
| 225 | + if (rowWidth == width) |
| 226 | + break; |
| 227 | + |
| 228 | + // Otherwise pad with zeros or partial product bits |
| 229 | + ++rowWidth; |
| 230 | + // Number of results indicates number of non-zero bits in input |
| 231 | + if (j >= op.getNumResults()) { |
| 232 | + row.push_back(zeroFalse); |
| 233 | + continue; |
| 234 | + } |
| 235 | + |
| 236 | + auto ppBit = |
| 237 | + rewriter.createOrFold<comb::AndOp>(loc, aBits[i], aBits[j]); |
| 238 | + row.push_back(ppBit); |
| 239 | + } |
| 240 | + std::reverse(row.begin(), row.end()); |
| 241 | + auto ppRow = comb::ConcatOp::create(rewriter, loc, row); |
| 242 | + partialProducts.push_back(ppRow); |
| 243 | + } |
| 244 | + |
| 245 | + rewriter.replaceOp(op, partialProducts); |
| 246 | + return success(); |
| 247 | + } |
| 248 | + |
169 | 249 | static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a, |
170 | 250 | Value b, PartialProductOp op, |
171 | 251 | unsigned width) { |
@@ -370,8 +450,8 @@ struct DatapathPosPartialProductOpConversion |
370 | 450 | unsigned width) { |
371 | 451 |
|
372 | 452 | Location loc = op.getLoc(); |
373 | | - // Encode (a+b) by implementing a half-adder - then note the following fact |
374 | | - // carry[i] & save[i] == false |
| 453 | + // Encode (a+b) by implementing a half-adder - then note the following |
| 454 | + // fact carry[i] & save[i] == false |
375 | 455 | auto carry = rewriter.createOrFold<comb::AndOp>(loc, a, b); |
376 | 456 | auto save = rewriter.createOrFold<comb::XorOp>(loc, a, b); |
377 | 457 |
|
|
0 commit comments