 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
+#include <cstdint>
 
 using namespace mlir;
 
@@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
     unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
     if (elemBitwidth >= maxShuffleBitwidth)
       return rewriter.notifyMatchFailure(
-          op, llvm::formatv("element type too large {0}, cannot break down "
+          op, llvm::formatv("element type too large ({0}), cannot break down "
                             "into vectors of bitwidth {1} or less",
                             elemBitwidth, maxShuffleBitwidth));
 
@@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final
   }
 };
 
+/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+/// and `unpackFn` to convert to the native shuffle type and to the reduction
+/// type, respectively. For example, with `input` of type `f16`, `packFn` could
+/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
+/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
+/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
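+///
+/// As a rough sketch of the shape of the emitted IR, assuming an `f16` input,
+/// a 32-bit shuffle width, and a `maximumf` reduction (the exact ops depend on
+/// the `packFn`/`unpackFn` supplied), one butterfly step could look like:
+///
+///   %b = arith.bitcast %lane : f16 to i16
+///   %p = arith.extui %b : i16 to i32
+///   %s, %ok = gpu.shuffle xor %p, %offset, %width : i32
+///   %t = arith.trunci %s : i32 to i16
+///   %u = arith.bitcast %t : i16 to f16
+///   %next = arith.maximumf %lane, %u : f16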
+static Value createSubgroupShuffleReduction(
+    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
+    unsigned subgroupSize, function_ref<Value(Value)> packFn,
+    function_ref<Value(Value)> unpackFn) {
+  assert(llvm::isPowerOf2_32(subgroupSize));
+  // Lane value always stays in the original type. We use it to perform arith
+  // reductions.
+  Value laneVal = input;
+  // Parallel reduction using butterfly shuffles.
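+  // E.g., with `subgroupSize == 8`, the loop uses XOR offsets 1, 2, and 4:
+  // after the offset-1 step each lane holds the reduction of its pair, after
+  // offset 2 of its group of four, and after offset 4 every lane holds the
+  // reduction over all 8 lanes.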
+  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+    Value shuffled = builder
+                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
+                                                 /*width=*/subgroupSize,
+                                                 /*mode=*/gpu::ShuffleMode::XOR)
+                         .getShuffleResult();
+    laneVal = vector::makeArithReduction(builder, loc,
+                                         gpu::convertReductionKind(mode),
+                                         laneVal, unpackFn(shuffled));
+    assert(laneVal.getType() == input.getType());
+  }
+
+  return laneVal;
+}
+
+/// Lowers scalar gpu subgroup reductions to a series of shuffles.
+struct ScalarSubgroupReduceToShuffles final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+                                 unsigned shuffleBitwidth,
+                                 PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        shuffleBitwidth(shuffleBitwidth) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    Type valueTy = op.getType();
+    unsigned elemBitwidth =
+        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
+    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "value type is not a compatible scalar");
+
+    Location loc = op.getLoc();
+    // Since this is already a native shuffle scalar, no packing is necessary.
+    if (elemBitwidth == shuffleBitwidth) {
+      auto identityFn = [](Value v) { return v; };
+      rewriter.replaceOp(op, createSubgroupShuffleReduction(
+                                 rewriter, loc, op.getValue(), op.getOp(),
+                                 subgroupSize, identityFn, identityFn));
+      return success();
+    }
+
+    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
+    auto packFn = [loc, &rewriter, equivIntType,
+                   shuffleIntType](Value unpackedVal) -> Value {
+      auto asInt =
+          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
+      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
+    };
+    auto unpackFn = [loc, &rewriter, equivIntType,
+                     valueTy](Value packedVal) -> Value {
+      auto asInt =
+          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
+      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
+    };
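+    // Note that `packFn` and `unpackFn` are exact inverses here: packing
+    // bitcasts the value to an integer of the same width and zero-extends it
+    // to the shuffle width (e.g., f16 -> i16 -> i32), while unpacking
+    // truncates and bitcasts back (i32 -> i16 -> f16), so the arith reduction
+    // always runs on the original scalar type.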
+
+    rewriter.replaceOp(op, createSubgroupShuffleReduction(
+                               rewriter, loc, op.getValue(), op.getOp(),
+                               subgroupSize, packFn, unpackFn));
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  unsigned shuffleBitwidth = 0;
+};
+
+/// Lowers vector gpu subgroup reductions to a series of shuffles.
+struct VectorSubgroupReduceToShuffles final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+                                 unsigned shuffleBitwidth,
+                                 PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        shuffleBitwidth(shuffleBitwidth) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(op, "value type is not a vector");
+
+    unsigned vecBitwidth =
+        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+    if (vecBitwidth > shuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op,
+          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
+                        "to shuffles of size {1}",
+                        vecBitwidth, shuffleBitwidth));
+
+    unsigned elementsPerShuffle =
+        shuffleBitwidth / vecTy.getElementTypeBitWidth();
+    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "shuffle bitwidth is not a multiple of the element bitwidth");
+
+    Location loc = op.getLoc();
+
+    // If the reduced type is smaller than the native shuffle size, extend it,
+    // perform the shuffles, and extract at the end.
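+    // E.g., with a 32-bit shuffle width, a `vector<3xi8>` input (24 bits) is
+    // inserted at offset 0 into a zero-filled `vector<4xi8>` (32 bits); the
+    // padding lane only ever combines with other padding lanes, so the real
+    // elements are unaffected and the padding is sliced off at the end.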
+    auto extendedVecTy = VectorType::get(
+        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
+    Value extendedInput = op.getValue();
+    if (vecBitwidth < shuffleBitwidth) {
+      auto zero = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(extendedVecTy));
+      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
+    }
+
+    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+    auto shuffleVecType = VectorType::get(1, shuffleIntType);
+
+    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
+      auto asIntVec =
+          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
+      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
+    };
+    auto unpackFn = [loc, &rewriter, shuffleVecType,
+                     extendedVecTy](Value packedVal) -> Value {
+      auto asIntVec =
+          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
+      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
+    };
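+    // Packing reinterprets the whole extended vector as one integer so it can
+    // be shuffled in a single op, e.g. `vector<4xi8>` -> `vector<1xi32>` ->
+    // `i32`; unpacking broadcasts the shuffled scalar back into a 1-element
+    // vector and bitcasts it to the extended vector type.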
+
+    Value res =
+        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
+                                       subgroupSize, packFn, unpackFn);
+
+    if (vecBitwidth < shuffleBitwidth) {
+      res = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
+          /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  unsigned shuffleBitwidth = 0;
+};
 } // namespace
 
 void mlir::populateGpuBreakDownSubgrupReducePatterns(
@@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
       maxShuffleBitwidth, benefit);
   patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
 }
+
+void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+}
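+
+// A minimal usage sketch (hypothetical driver code; `ctx`, `op`, and the
+// subgroup/shuffle sizes are placeholder assumptions, not part of this patch):
+//
+//   RewritePatternSet patterns(ctx);
+//   populateGpuLowerSubgroupReduceToShufflePattenrs(
+//       patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, /*benefit=*/1);
+//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));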