Skip to content

Commit c0345b4

Browse files
authored
[mlir][gpu] Add subgroup_reduce to shuffle lowering (llvm#76530)
This supports both the scalar and the vector multi-reduction cases.
1 parent 8c7dfaf commit c0345b4

File tree

8 files changed

+408
-64
lines changed

8 files changed

+408
-64
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
7070
unsigned maxShuffleBitwidth = 32,
7171
PatternBenefit benefit = 1);
7272

73+
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
74+
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
75+
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
76+
void populateGpuLowerSubgroupReduceToShufflePattenrs(
77+
RewritePatternSet &patterns, unsigned subgroupSize,
78+
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
79+
7380
/// Collect all patterns to rewrite ops within the GPU dialect.
7481
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
7582
populateGpuAllReducePatterns(patterns);

mlir/include/mlir/Dialect/GPU/Transforms/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#ifndef MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
1414
#define MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
1515

16+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1618
#include "mlir/Support/LLVM.h"
1719

1820
#include <string>
@@ -28,6 +30,9 @@ class LaunchOp;
2830

2931
/// Returns the default annotation name for GPU binary blobs.
3032
std::string getDefaultGpuBinaryAnnotation();
33+
34+
/// Returns the matching vector combining kind.
35+
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode);
3136
} // namespace gpu
3237

3338
/// Get a gpu.func created from outlining the region of a gpu.launch op with the

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
6464
Transforms/ShuffleRewriter.cpp
6565
Transforms/SPIRVAttachTarget.cpp
6666
Transforms/SubgroupReduceLowering.cpp
67+
Transforms/Utils.cpp
6768

6869
ADDITIONAL_HEADER_DIRS
6970
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,6 @@ using namespace mlir;
2727

2828
namespace {
2929

30-
static vector::CombiningKind
31-
convertReductionKind(gpu::AllReduceOperation mode) {
32-
switch (mode) {
33-
#define MAP_CASE(X) \
34-
case gpu::AllReduceOperation::X: \
35-
return vector::CombiningKind::X
36-
37-
MAP_CASE(ADD);
38-
MAP_CASE(MUL);
39-
MAP_CASE(MINUI);
40-
MAP_CASE(MINSI);
41-
MAP_CASE(MINNUMF);
42-
MAP_CASE(MAXSI);
43-
MAP_CASE(MAXUI);
44-
MAP_CASE(MAXNUMF);
45-
MAP_CASE(AND);
46-
MAP_CASE(OR);
47-
MAP_CASE(XOR);
48-
MAP_CASE(MINIMUMF);
49-
MAP_CASE(MAXIMUMF);
50-
51-
#undef MAP_CASE
52-
}
53-
54-
llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
55-
}
56-
5730
struct GpuAllReduceRewriter {
5831
using AccumulatorFactory = std::function<Value(Value, Value)>;
5932

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1515
#include "mlir/Dialect/GPU/Transforms/Passes.h"
16+
#include "mlir/Dialect/GPU/Transforms/Utils.h"
1617
#include "mlir/Dialect/Vector/IR/VectorOps.h"
18+
#include "mlir/IR/BuiltinTypes.h"
1719
#include "mlir/IR/Location.h"
1820
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/IR/TypeUtilities.h"
1922
#include "mlir/Support/LogicalResult.h"
2023
#include "llvm/Support/FormatVariadic.h"
2124
#include "llvm/Support/MathExtras.h"
2225
#include <cassert>
26+
#include <cstdint>
2327

2428
using namespace mlir;
2529

@@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
5862
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
5963
if (elemBitwidth >= maxShuffleBitwidth)
6064
return rewriter.notifyMatchFailure(
61-
op, llvm::formatv("element type too large {0}, cannot break down "
65+
op, llvm::formatv("element type too large ({0}), cannot break down "
6266
"into vectors of bitwidth {1} or less",
6367
elemBitwidth, maxShuffleBitwidth));
6468

@@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final
139143
}
140144
};
141145

146+
/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
///
/// Performs log2(subgroupSize) shuffle+combine steps in total; after step k,
/// every lane holds the reduction of its 2^(k+1)-lane neighborhood.
static Value createSubgroupShuffleReduction(
    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
    unsigned subgroupSize, function_ref<Value(Value)> packFn,
    function_ref<Value(Value)> unpackFn) {
  // The butterfly pattern below only covers all lanes when the width is a
  // power of two.
  assert(llvm::isPowerOf2_32(subgroupSize));
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
    // XOR mode pairs each lane with `laneId ^ i`, doubling the reduced
    // neighborhood each iteration.
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    // `unpackFn` must restore the original reduction type on every step.
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}
175+
176+
/// Lowers scalar gpu subgroup reductions to a series of shuffles.
177+
struct ScalarSubgroupReduceToShuffles final
178+
: OpRewritePattern<gpu::SubgroupReduceOp> {
179+
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
180+
unsigned shuffleBitwidth,
181+
PatternBenefit benefit)
182+
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
183+
shuffleBitwidth(shuffleBitwidth) {}
184+
185+
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
186+
PatternRewriter &rewriter) const override {
187+
Type valueTy = op.getType();
188+
unsigned elemBitwidth =
189+
getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
190+
if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
191+
return rewriter.notifyMatchFailure(
192+
op, "value type is not a compatible scalar");
193+
194+
Location loc = op.getLoc();
195+
// Since this is already a native shuffle scalar, no packing is necessary.
196+
if (elemBitwidth == shuffleBitwidth) {
197+
auto identityFn = [](Value v) { return v; };
198+
rewriter.replaceOp(op, createSubgroupShuffleReduction(
199+
rewriter, loc, op.getValue(), op.getOp(),
200+
subgroupSize, identityFn, identityFn));
201+
return success();
202+
}
203+
204+
auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
205+
auto equivIntType = rewriter.getIntegerType(elemBitwidth);
206+
auto packFn = [loc, &rewriter, equivIntType,
207+
shuffleIntType](Value unpackedVal) -> Value {
208+
auto asInt =
209+
rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
210+
return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
211+
};
212+
auto unpackFn = [loc, &rewriter, equivIntType,
213+
valueTy](Value packedVal) -> Value {
214+
auto asInt =
215+
rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
216+
return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
217+
};
218+
219+
rewriter.replaceOp(op, createSubgroupShuffleReduction(
220+
rewriter, loc, op.getValue(), op.getOp(),
221+
subgroupSize, packFn, unpackFn));
222+
return success();
223+
}
224+
225+
private:
226+
unsigned subgroupSize = 0;
227+
unsigned shuffleBitwidth = 0;
228+
};
229+
230+
/// Lowers vector gpu subgroup reductions to a series of shuffles.
///
/// Only handles vectors that fit into a single native shuffle
/// (`vecBitwidth <= shuffleBitwidth`); wider vectors are expected to be broken
/// down first (see `BreakDownSubgroupReduce` earlier in this file).
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    // Total payload size of the reduced vector, in bits.
    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    // The pack/unpack bitcasts below require the element bitwidth to divide
    // the shuffle bitwidth exactly.
    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      // Pad with zeros; the padding lanes are reduced too but their result is
      // discarded by the extract below.
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    // Shuffles move a single shuffle-width integer; pack/unpack bitcast
    // between that and the (extended) element vector.
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res =
        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
                                       subgroupSize, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      // Drop the zero padding added above, restoring the original vector type.
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};
142307
} // namespace
143308

144309
void mlir::populateGpuBreakDownSubgrupReducePatterns(
@@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
148313
maxShuffleBitwidth, benefit);
149314
patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
150315
}
316+
317+
// Registers both the scalar and the vector subgroup-reduce-to-shuffle
// lowerings declared above.
// NOTE(review): "Pattenrs" is a typo for "Patterns" baked into the public
// name (also in the header declaration); renaming it would break callers, so
// it is kept as-is here.
void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- Utils.cpp - GPU transforms utils -----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Implements GPU dialect transforms utils.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/GPU/Transforms/Utils.h"
14+
#include "llvm/Support/ErrorHandling.h"
15+
16+
namespace mlir::gpu {
17+
18+
/// Maps a `gpu.all_reduce`/`gpu.subgroup_reduce` operation kind to the
/// equivalent vector combining kind. The two enums are expected to match 1:1;
/// an unknown value is a programming error.
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return vector::CombiningKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return vector::CombiningKind::MUL;
  case gpu::AllReduceOperation::MINUI:
    return vector::CombiningKind::MINUI;
  case gpu::AllReduceOperation::MINSI:
    return vector::CombiningKind::MINSI;
  case gpu::AllReduceOperation::MINNUMF:
    return vector::CombiningKind::MINNUMF;
  case gpu::AllReduceOperation::MAXSI:
    return vector::CombiningKind::MAXSI;
  case gpu::AllReduceOperation::MAXUI:
    return vector::CombiningKind::MAXUI;
  case gpu::AllReduceOperation::MAXNUMF:
    return vector::CombiningKind::MAXNUMF;
  case gpu::AllReduceOperation::AND:
    return vector::CombiningKind::AND;
  case gpu::AllReduceOperation::OR:
    return vector::CombiningKind::OR;
  case gpu::AllReduceOperation::XOR:
    return vector::CombiningKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
    return vector::CombiningKind::MINIMUMF;
  case gpu::AllReduceOperation::MAXIMUMF:
    return vector::CombiningKind::MAXIMUMF;
  }

  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
}
43+
44+
} // namespace mlir::gpu

0 commit comments

Comments
 (0)