diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index 43dbb31da..83c46665d 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -120,6 +120,36 @@ def GpuToGpuOcl : Pass<"gpu-to-gpuocl", "ModuleOp"> {
            "Call finish() after each kernel launch.">
   ];
 }
+
+def GpuTilingAndFusion : Pass<"gpu-tiling", "func::FuncOp"> {
+  let summary = "GPU tiling and fusion path.";
+  let description = [{
+    This pass tiles linalg operations and creates two nested scf.forall loops. When converting to gpu.launch,
+    the inner loop is mapped to the block sizes and the outer loop to the grid sizes. The tile sizes are
+    calculated from the GPU device properties, retrieved from the DLTI attributes. If the DLTI attributes
+    are not specified, the pass options are used as defaults.
+  }];
+  let options = [
+    Option<"numEus", "num-eus", "size_t",
+           /*default=*/"448",
+           "Number of Execution Units.">,
+    Option<"numEusPerSlice", "num-eus-per-slice", "size_t",
+           /*default=*/"8",
+           "Number of Execution Units per slice.">,
+    Option<"numThreadsPerEu", "num-threads-per-eu", "size_t",
+           /*default=*/"8",
+           "Number of threads per Execution Unit.">,
+    Option<"localMemSize", "local-mem-size", "size_t",
+           /*default=*/"131072",
+           "The size of the local memory, shared across a work-group.">,
+    Option<"vectorWidth", "vector-width", "size_t",
+           /*default=*/"512",
+           "The maximum width of EU's vector registers.">,
+    Option<"workGroupSize", "work-group-size", "size_t",
+           /*default=*/"64",
+           "The maximum work-group size.">
+  ];
+}
 #endif // GC_USE_IMEX
 
 def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
diff --git a/lib/gc/ExecutionEngine/GPURuntime/ocl/GpuOclRuntime.cpp b/lib/gc/ExecutionEngine/GPURuntime/ocl/GpuOclRuntime.cpp
index 97fe1760c..6798cd41a 100644
--- a/lib/gc/ExecutionEngine/GPURuntime/ocl/GpuOclRuntime.cpp
+++ b/lib/gc/ExecutionEngine/GPURuntime/ocl/GpuOclRuntime.cpp
@@ -876,8 +876,7 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
       {CL_DEVICE_MAX_COMPUTE_UNITS, "num_exec_units"},
       {CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, "num_exec_units_per_slice"},
       {CL_DEVICE_NUM_THREADS_PER_EU_INTEL, "num_threads_per_eu"},
-      // Assuming the cache size is equal to the local mem
-      {CL_DEVICE_LOCAL_MEM_SIZE, "L1_cache_size_in_bytes"},
+      {CL_DEVICE_LOCAL_MEM_SIZE, "local_mem_size"},
   };
 
   unsigned i = 0;
diff --git a/lib/gc/Transforms/GPU/CMakeLists.txt b/lib/gc/Transforms/GPU/CMakeLists.txt
index 28981f200..f4b286b94 100644
--- a/lib/gc/Transforms/GPU/CMakeLists.txt
+++ b/lib/gc/Transforms/GPU/CMakeLists.txt
@@ -13,6 +13,7 @@ set_property(GLOBAL APPEND PROPERTY IMEX_LIBS ${IMEX_LIBS})
 gc_add_mlir_library(GcGpuPasses
   AddContextArg.cpp
   AllocsToSLM.cpp
+  GpuTilingAndFusion.cpp
   GpuToGpuOcl.cpp
   LinalgToXeGPU.cpp
   Pipeline.cpp
diff --git a/lib/gc/Transforms/GPU/GpuTilingAndFusion.cpp b/lib/gc/Transforms/GPU/GpuTilingAndFusion.cpp
new file mode 100644
index 000000000..45299fa48
--- /dev/null
+++ b/lib/gc/Transforms/GPU/GpuTilingAndFusion.cpp
@@ -0,0 +1,545 @@
+//===-- GpuTilingAndFusion.cpp - GPU tiling and fusion pass -----*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
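+
+// A sketch of the loop nest produced by this pass (illustrative only, not
+// verbatim pass output):
+//
+//   scf.forall (outer indices) ... {   // mapped to the grid sizes
+//     scf.forall (inner indices) ... { // mapped to the block sizes
+//       // tiled and fused linalg ops on the extracted slices
+//     }
+//   }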
+
+#include "./GpuUtils.h"
+#include "gc/Dialect/Linalgx/Utils.h"
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+using namespace mlir::gc;
+using namespace mlir::scf;
+
+namespace mlir::gc {
+#define GEN_PASS_DECL_GPUTILINGANDFUSION
+#define GEN_PASS_DEF_GPUTILINGANDFUSION
+#include "gc/Transforms/Passes.h.inc"
+} // namespace mlir::gc
+
+namespace {
+
+struct GpuTilingAndFusion final
+    : GpuPass<GpuTilingAndFusion>,
+      gc::impl::GpuTilingAndFusionBase<GpuTilingAndFusion> {
+  friend struct GpuPass<GpuTilingAndFusion>;
+  explicit GpuTilingAndFusion()
+      : GpuTilingAndFusion(GpuTilingAndFusionOptions{}) {}
+  explicit GpuTilingAndFusion(const GpuTilingAndFusionOptions &opts)
+      : GpuPass(), GpuTilingAndFusionBase(opts) {}
+
+  void runOnOperation() override {
+    auto fn = getOperation();
+    if (fn.isExternal()) {
+      return;
+    }
+
+    OpRewriter rw(fn);
+    auto loopMarker = rw.getStringAttr("gcGpuLoop");
+    tileAndFuseLinalgOps(rw, fn, loopMarker);
+    tileForallOps(rw, fn, loopMarker);
+  }
+
+private:
+  void tileAndFuseLinalgOps(OpRewriter &rw, func::FuncOp &fn,
+                            StringAttr &loopMarker) {
+    auto markerValue = rw.getBoolAttr(true);
+    auto numEus = getNumEus(rw);
+    auto numEusPerSlice = getNumEusPerSlice(rw);
+    auto numThreadsPerEu = getNumThreadsPerEu(rw);
+    auto localMemSize = getLocalMemSize(rw);
+    auto vectorWidth = getVectorWidth(rw);
+    auto cachePerThread =
+        std::max(localMemSize / numEusPerSlice / numThreadsPerEu, vectorWidth);
+    SCFTileAndFuseOptions opts;
+    opts.tilingOptions.setTileSizeComputationFunction(
+        [&rw, cachePerThread, vectorWidth,
+         numThreads = numEus * numThreadsPerEu](
+            OpBuilder &builder, Operation *op) -> SmallVector<OpFoldResult> {
+          auto ti = dyn_cast<TilingInterface>(op);
+          if (!ti) {
+            return {};
+          }
+
+          rw.loc = op->getLoc();
+          rw.setInsertionPoint(op);
+          auto itTypes = ti.getLoopIteratorTypes();
+          auto itDomains = ti.getIterationDomain(builder);
+          assert(itTypes.size() == itDomains.size());
+
+          SmallVector<int64_t> sizes;
+          int64_t maxSize = 0;
+          int64_t numIterations = 1;
+          for (auto [t, r] : zip(itTypes, itDomains)) {
+            if (t == utils::IteratorType::parallel) {
+              if (auto v = getConstantIntValue(r.size)) {
+                numIterations *= *v;
+                sizes.emplace_back(*v);
+                maxSize = std::max(maxSize, *v);
+              } else {
+                return computeDynamicTiles(rw, ti, numThreads, cachePerThread);
+              }
+            }
+          }
+
+          assert(!sizes.empty());
+          auto elementSize = getElementSize(op);
+          auto sizePerThread = numIterations / numThreads * elementSize;
+          auto totalSize = std::max(sizePerThread, cachePerThread);
+          totalSize = std::max(totalSize / elementSize, 64L);
+          int64_t minTileSize = 1;
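+
+          // For example, with the default pass options: cachePerThread =
+          // max(131072 / 8 / 8, 512) = 2048 bytes; for an f16 operation
+          // (elementSize = 2) this yields a budget of at least
+          // max(2048 / 2, 64) = 1024 elements per tile (illustrative numbers
+          // derived from the defaults above).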
+
+          // If the operation can be lowered to XeGPU, make the tiles
+          // multiples of the vector width and set the minimum tile size to 8.
+          if (canLowerToXeGPU(op)) {
+            minTileSize = 8;
+            totalSize = std::max(totalSize / vectorWidth, 1L) * vectorWidth;
+          }
+
+          SmallVector<int64_t> tiles = sizes;
+          adjustTiles(totalSize, tiles, minTileSize);
+
+          // If the tiles are equal to the sizes, split the largest tile to
+          // prevent the loops from being eliminated by the canonicalizer
+          // pass.
+          if (tiles == sizes) {
+            auto tile = findFactor(maxSize, maxSize / 2);
+
+            if (tile == maxSize) {
+              // Find another size that can be split
+              auto another = maxSize;
+              sort(sizes, std::greater<>());
+              for (auto s : sizes) {
+                if (s != maxSize && (tile = findFactor(s, s / 2)) != s) {
+                  another = s;
+                  break;
+                }
+              }
+              if (another == maxSize) {
+                tile = 1;
+                // Find the smallest size that is not 1
+                for (auto s : reverse(sizes)) {
+                  if (s != 1) {
+                    maxSize = s;
+                    break;
+                  }
+                }
+              } else {
+                maxSize = another;
+              }
+            }
+
+            for (auto &t : tiles) {
+              if (t == maxSize) {
+                t = tile;
+                break;
+              }
+            }
+          }
+
+          unsigned counter = 0;
+          SmallVector<OpFoldResult> result;
+          result.reserve(itDomains.size());
+
+          for (auto [t, r] : zip(itTypes, itDomains)) {
+            if (t == utils::IteratorType::parallel) {
+              result.emplace_back(rw.createConstant(tiles[counter++]));
+            } else {
+              result.emplace_back(rw.createConstant(0L));
+            }
+          }
+
+          return result;
+        });
+    opts.setFusionControlFn(
+        [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+            bool) -> std::optional<SCFTileAndFuseOptions::ControlFnResult> {
+          Operation *op = originalProducer.getOwner();
+          if (!op) {
+            return std::nullopt;
+          }
+          if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+            if (!linalgOp.hasOnlyProjectedPermutations()) {
+              return std::nullopt;
+            }
+          }
+
+          // If the result of this slice is used by a MatmulOp and the slice
+          // has an operand produced by a previous MatmulOp, do not fuse.
+          if (isOpDependsOnResult<0>(linalgx::isMatmulOp, candidateSliceOp) &&
+              isOperandDependsOnOp(linalgx::isMatmulOp, candidateSliceOp)) {
+            return std::nullopt;
+          }
+
+          return SCFTileAndFuseOptions::ControlFnResult{};
+        });
+    opts.tilingOptions.setLoopType(SCFTilingOptions::LoopType::ForallOp);
+
+    for (auto ti = findTi(rw, fn, loopMarker); ti;
+         ti = findTi(rw, fn, loopMarker)) {
+      auto result = tileConsumerAndFuseProducersUsingSCF(rw, *ti, opts);
+
+      if (failed(result)) {
+        ti->emitError() << "Failed to tile and fuse using SCF";
+        return;
+      }
+
+      SmallVector<Operation *> opsToReplace{ti->getOperation()};
+      append_range(opsToReplace, result->fusedProducers);
+      for (Operation *toReplace : opsToReplace) {
+        for (OpResult res : toReplace->getResults()) {
+          if (auto repl = result->replacements.lookup(res)) {
+            rw.replaceAllUsesWith(res, repl);
+            if (auto loop = dyn_cast<ForallOp>(repl.getDefiningOp())) {
+              loop->setAttr(loopMarker, markerValue);
+            }
+          }
+        }
+      }
+
+      if (failed(simplifyRegions(rw, fn->getRegions()))) {
+        // Not simplified
+      }
+    }
+  }
+
+  static std::optional<TilingInterface> findTi(OpBuilder &b, Operation *op,
+                                               const StringAttr &loopMarker) {
+    std::optional<TilingInterface> last;
+    op->walk([&](linalg::LinalgOp linalgOp) {
+      if (!linalgOp.hasOnlyProjectedPermutations()) {
+        return WalkResult::skip();
+      }
+      if (auto parentLoop = linalgOp->getParentOfType<ForallOp>();
+          parentLoop && parentLoop->hasAttr(loopMarker)) {
+        return WalkResult::skip();
+      }
+
+      if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation())) {
+        int64_t numTiles = 0;
+        int64_t numIterations = 1;
+        for (auto [t, r] :
+             zip(ti.getLoopIteratorTypes(), ti.getIterationDomain(b))) {
+          if (t == utils::IteratorType::parallel) {
+            numTiles++;
+            if (auto v = getConstantIntValue(r.size)) {
+              numIterations *= *v;
+            }
+          }
+        }
+        if (numTiles > 0 && numIterations >= 32) {
+          last = ti;
+        }
+      }
+
+      return WalkResult::skip();
+    });
+    return last;
+  }
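+
+  // Worked example of the computation below (illustrative numbers): for a
+  // 2-D parallel domain whose runtime size turns out to be 1024x1024 f32
+  // elements, with the default numThreads = 448 * 8 = 3584 and
+  // cachePerThread = 2048, totalSize = max(1024 * 1024 * 4 / 3584, 2048) =
+  // 2048 and the average tile size is trunc(2048 ^ (1 / 2)) = 45.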
+  static SmallVector<OpFoldResult> computeDynamicTiles(OpRewriter &rw,
+                                                       TilingInterface ti,
+                                                       int64_t numThreads,
+                                                       int64_t cachePerThread) {
+    auto itTypes = ti.getLoopIteratorTypes();
+    auto itDomains = ti.getIterationDomain(rw);
+    assert(itTypes.size() == itDomains.size());
+    rw.loc = ti.getLoc();
+    rw.setInsertionPoint(ti.getOperation());
+
+    Value dynamicSize;
+    auto staticSize = getElementSize(ti.getOperation());
+    unsigned loopCount = 0;
+
+    for (auto [t, r] : zip(itTypes, itDomains)) {
+      if (t != utils::IteratorType::parallel) {
+        continue;
+      }
+      loopCount++;
+      if (auto v = getConstantIntValue(r.size)) {
+        staticSize *= *v;
+      } else if (dynamicSize) {
+        dynamicSize =
+            rw.create<arith::MulIOp>(dynamicSize, r.size.get<Value>());
+      } else {
+        dynamicSize = r.size.get<Value>();
+      }
+    }
+
+    assert(loopCount);
+    assert(dynamicSize);
+    if (staticSize > 1) {
+      dynamicSize = rw.create<arith::MulIOp>(dynamicSize,
+                                             rw.createConstant(staticSize));
+    }
+    auto i64Type = rw.getI64Type();
+    dynamicSize = rw.create<arith::SIToFPOp>(
+        rw.getF64Type(), rw.create<arith::IndexCastOp>(i64Type, dynamicSize));
+
+    // TODO: Call the adjustTiles() function for the tile size calculation.
+
+    auto nt = rw.createConstant(static_cast<double>(numThreads));
+    auto cpt = rw.createConstant(static_cast<double>(cachePerThread));
+    Value totalSize = rw.create<arith::MaximumFOp>(
+        rw.getF64Type(), rw.create<arith::DivFOp>(dynamicSize, nt), cpt);
+    auto pow = rw.createConstant(1.0 / loopCount);
+    // The average tile size is totalSize^(1 / loopCount)
+    Value avgTileSize = rw.create<math::PowFOp>(totalSize, pow);
+    avgTileSize = rw.create<arith::MaximumFOp>(
+        rw.getF64Type(), rw.createConstant(1.0), avgTileSize);
+    avgTileSize = rw.create<arith::FPToSIOp>(i64Type, avgTileSize);
+
+    SmallVector<OpFoldResult> tiles;
+    tiles.reserve(itDomains.size());
+
+    for (auto [t, r] : zip(itTypes, itDomains)) {
+      if (t != utils::IteratorType::parallel) {
+        tiles.emplace_back(rw.getIndexAttr(1));
+      } else {
+        Value value;
+        if (auto v = getConstantIntValue(r.size)) {
+          value = rw.create<arith::ConstantIntOp>(*v, i64Type);
+        } else {
+          value = rw.create<arith::IndexCastOp>(i64Type, r.size.get<Value>());
+        }
+        value = rw.create<arith::MinSIOp>(i64Type, value, avgTileSize);
+        tiles.emplace_back(
+            rw.create<arith::IndexCastOp>(rw.getIndexType(), value));
+      }
+    }
+
+    return tiles;
+  }
+
+  static int64_t getElementSize(Operation *op) {
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      if (auto inits = linalgOp.getDpsInits(); !inits.empty()) {
+        if (auto t = getElementTypeOrSelf(inits[0].getType());
+            t.isIntOrFloat()) {
+          return std::max(1L, t.getIntOrFloatBitWidth() / 8L);
+        }
+      }
+    }
+    return 1L;
+  }
+
+  // TODO: Add more checks
+  static bool canLowerToXeGPU(Operation *operation) {
+    auto op = dyn_cast<linalg::LinalgOp>(operation);
+    if (!op) {
+      return false;
+    }
+    if (op.hasDynamicShape()) {
+      return false;
+    }
+
+    auto checkOperand = [&](Value operand, bool isOutput = false) {
+      ShapedType type;
+      if (auto memref = dyn_cast<MemRefType>(operand.getType())) {
+        type = memref;
+      } else if (auto tensor = dyn_cast<RankedTensorType>(operand.getType())) {
+        type = tensor;
+      } else {
+        return false;
+      }
+
+      auto shape = type.getShape();
+      if (isOutput) {
+        if (shape.size() != 2 || shape[0] * shape[1] < 16) {
+          return false;
+        }
+      } else if (shape.size() > 2) {
+        return false;
+      }
+
+      return true;
+    };
+
+    if (auto inits = op.getDpsInits();
+        !inits.empty() && !checkOperand(inits[0], true)) {
+      return false;
+    }
+
+    if (auto inputs = op.getDpsInputs();
+        !std::all_of(inputs.begin(), inputs.end(),
+                     [&](Value v) { return checkOperand(v); })) {
+      return false;
+    }
+
+    return true;
+  }
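+
+  // Splits each marked scf.forall loop into two nested scf.forall loops,
+  // limiting the inner loop's iteration count by the work-group size. For
+  // example (numbers taken from the adjustTiles() unit tests): with
+  // wgSize = 8 and inner steps {16, 32}, the work-group tiles become {2, 4}
+  // and the outer steps {32, 128}.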
+  void tileForallOps(OpRewriter &rw, func::FuncOp &fn, StringAttr &loopMarker) {
+    auto wgSize = getWorkGroupSize(rw);
+    fn.walk([&rw, wgSize, loopMarker](ForallOp loop) {
+      if (loop->removeAttr(loopMarker)) {
+        replaceEmptySlices(rw, loop);
+
+        // If there is only one user, and it's located in a different block,
+        // and this block is not inside a loop, move the loop to the user
+        // block.
+        if (loop->hasOneUse()) {
+          auto user = *loop->getUsers().begin();
+          if (user->getBlock() != loop->getBlock()) {
+            if (!user->getParentOfType<LoopLikeOpInterface>()) {
+              loop->moveBefore(user);
+            }
+          }
+        }
+
+        tileForallOp(rw, loop, wgSize);
+      }
+      return WalkResult::skip();
+    });
+  }
+
+  // If a slice inside the loop is created from an external empty tensor and
+  // the tensor is not passed to the loop's shared_outs, but referenced
+  // directly, replace the slice with an empty tensor of the same size.
+  static void replaceEmptySlices(OpRewriter &rw, ForallOp loop) {
+    loop.walk([&](tensor::ExtractSliceOp slice) {
+      if (auto empty = slice.getSource().getDefiningOp<tensor::EmptyOp>();
+          empty && empty->getParentOfType<ForallOp>() != loop) {
+        auto type = slice.getType();
+        rw.setInsertionPointAfter(slice);
+        SmallVector<Value> dynDims;
+        for (int64_t i = 0, r = type.getRank(); i < r; ++i) {
+          if (type.isDynamicDim(i)) {
+            dynDims.push_back(rw.create<tensor::DimOp>(slice, i));
+          }
+        }
+        rw.replaceOp(slice, rw.create<tensor::EmptyOp>(type.getShape(),
+                                                       type.getElementType(),
+                                                       dynDims));
+      }
+    });
+  }
+
+  static void tileForallOp(OpRewriter &rw, ForallOp op, int64_t wgSize) {
+    rw.loc = op.getLoc();
+    rw.setInsertionPoint(op);
+    OpFoldResult zero{rw.createConstant(0L)};
+    OpFoldResult one{rw.createConstant(1L)};
+    auto innerSteps = op.getMixedStep();
+    auto outerBounds = op.getMixedUpperBound();
+    SmallVector<OpFoldResult> outerSteps;
+    auto count = innerSteps.size();
+
+    { // Calculate outer steps
+      SmallVector<int64_t> tiles;
+      tiles.reserve(count);
+      for (auto s : innerSteps) {
+        if (auto v = getConstantIntValue(s)) {
+          tiles.emplace_back(*v);
+        } else {
+          // TODO: Add support for dynamic sizes
+          tiles.emplace_back(32);
+        }
+      }
+      adjustTiles(wgSize, tiles);
+      outerSteps.reserve(count);
+      for (auto [s, b, t] : zip(innerSteps, outerBounds, tiles)) {
+        if (auto sv = getConstantIntValue(s)) {
+          auto step = *sv * t;
+          if (auto bv = getConstantIntValue(b)) {
+            step = std::min(step, *bv);
+          }
+          outerSteps.emplace_back(rw.createConstant(step));
+        } else {
+          outerSteps.emplace_back(
+              rw.create<arith::MulIOp>(s.get<Value>(), rw.createConstant(t)));
+        }
+      }
+    }
+
+    auto outerLoop =
+        rw.create<ForallOp>(op.getMixedLowerBound(), outerBounds, outerSteps,
+                            op.getOutputs(), std::nullopt);
+    rw.setInsertionPointToStart(outerLoop.getBody());
+    SmallVector<OpFoldResult> innerBounds;
+    SmallVector<Range> ranges;
+
+    {
+      auto idxType = rw.getIndexType();
+      auto ctx = rw.getContext();
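+      // minMap computes min(step, upperBound - iv), i.e. the size of the
+      // last, possibly incomplete, tile. For example, with an upper bound of
+      // 100 and a step of 32, the trailing tile gets min(32, 100 - 96) = 4
+      // iterations (illustrative numbers).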
+      auto minMap = AffineMap::get(
+          /*dimCount=*/3, /*symbolCount=*/0,
+          {getAffineDimExpr(0, ctx),
+           getAffineDimExpr(1, ctx) - getAffineDimExpr(2, ctx)},
+          rw.getContext());
+      innerBounds.reserve(count);
+      ranges.reserve(count);
+      for (auto [i, u, s] : zip(outerLoop.getInductionVars(),
+                                outerLoop.getMixedUpperBound(), outerSteps)) {
+        OpFoldResult iub;
+        auto cu = getConstantIntValue(u);
+        auto cs = getConstantIntValue(s);
+        if (cu && cs && (*cu % *cs == 0)) {
+          iub = s;
+        } else {
+          Value vub = cu ? rw.createConstant(*cu) : u.get<Value>();
+          Value vs = cs ? rw.createConstant(*cs) : s.get<Value>();
+          iub = OpFoldResult(rw.create<affine::AffineMinOp>(
+              idxType, minMap, ValueRange{vs, vub, i}));
+        }
+        innerBounds.emplace_back(iub);
+        ranges.emplace_back(Range{i, iub, one});
+      }
+    }
+
+    SmallVector<Value> innerOutputs;
+    for (auto o : outerLoop.getRegionIterArgs()) {
+      innerOutputs.emplace_back(rw.create<tensor::ExtractSliceOp>(o, ranges));
+    }
+
+    auto innerLoop =
+        rw.create<ForallOp>(SmallVector<OpFoldResult>(count, zero),
+                            innerBounds, innerSteps, innerOutputs,
+                            op.getMapping());
+    SmallVector<Type> argTypes{innerLoop.getBody()->getArgumentTypes()};
+    innerLoop.getRegion().takeBody(op.getRegion());
+    for (auto [arg, type] :
+         zip(innerLoop.getBody()->getArguments(), argTypes)) {
+      arg.setType(type);
+    }
+
+    // Collect all users of the inner loop outputs
+    llvm::SmallSet<Operation *, 8> outUsers;
+    for (auto out : innerLoop.getRegionIterArgs()) {
+      for (auto user : out.getUsers()) {
+        outUsers.insert(user);
+      }
+    }
+
+    // Replace the induction variables of the inner loop with the sum of the
+    // outer and inner induction variables, but only in the operations that
+    // are not using the inner loop outputs, which are already sliced.
+    rw.setInsertionPointToStart(innerLoop.getBody());
+    for (auto [inIdx, outIdx] :
+         zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) {
+      auto newIdx = rw.create<arith::AddIOp>(inIdx, outIdx);
+      outUsers.insert(newIdx);
+      inIdx.replaceAllUsesExcept(newIdx, outUsers);
+    }
+
+    rw.setInsertionPointToStart(outerLoop.getTerminator().getBody());
+    for (auto [i, o] :
+         zip(innerLoop.getResults(), outerLoop.getRegionIterArgs())) {
+      rw.create<tensor::ParallelInsertSliceOp>(i, o, ranges);
+    }
+
+    rw.replaceOp(op, outerLoop);
+  }
+};
+} // namespace
diff --git a/lib/gc/Transforms/GPU/GpuUtils.h b/lib/gc/Transforms/GPU/GpuUtils.h
new file mode 100644
index 000000000..2bc079978
--- /dev/null
+++ b/lib/gc/Transforms/GPU/GpuUtils.h
@@ -0,0 +1,311 @@
+//===-- GpuUtils.h - Common utilities for GPU passes ------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef GPUUTILS_H
+#define GPUUTILS_H
+
+#include "gc/Utils/Log.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+#include <cmath>
+#include <numeric>
+
+using namespace mlir;
+
+namespace mlir::gc {
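+// Base class for GPU passes that read the device parameters. The parameters
+// are taken from the module's DLTI attributes, e.g. (an illustrative snippet
+// in the format used by the unit tests):
+//
+//   module attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"GPU"
+//     : #dlti.target_device_spec<
+//         #dlti.dl_entry<"max_work_group_size", 16 : i64>>>} {}
+//
+// If an attribute is missing, the corresponding pass option is used as a
+// default.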
+template <typename DerivedT> struct GpuPass {
+
+  int64_t getGpuPropertyAsInt(Builder &builder, StringRef name,
+                              int64_t defaultValue) {
+    if (auto mod = static_cast<DerivedT *>(this)
+                       ->getOperation()
+                       ->template getParentOfType<ModuleOp>()) {
+      DataLayout layout(mod);
+      if (auto value = layout.getDevicePropertyValue(
+              builder.getStringAttr("GPU" /* device ID*/),
+              builder.getStringAttr(name))) {
+        if (auto attr = dyn_cast<IntegerAttr>(*value)) {
+          return attr.getInt();
+        }
+      }
+    }
+    return defaultValue;
+  }
+
+  int64_t getNumEus(Builder &builder) {
+    return getGpuPropertyAsInt(builder, "num_exec_units",
+                               static_cast<DerivedT *>(this)->numEus);
+  }
+
+  int64_t getNumEusPerSlice(Builder &builder) {
+    return getGpuPropertyAsInt(builder, "num_exec_units_per_slice",
+                               static_cast<DerivedT *>(this)->numEusPerSlice);
+  }
+
+  int64_t getNumThreadsPerEu(Builder &builder) {
+    return getGpuPropertyAsInt(builder, "num_threads_per_eu",
+                               static_cast<DerivedT *>(this)->numThreadsPerEu);
+  }
+
+  int64_t getLocalMemSize(Builder &builder) {
+    return getGpuPropertyAsInt(builder, "local_mem_size",
+                               static_cast<DerivedT *>(this)->localMemSize);
+  }
+
+  int64_t getVectorWidth(Builder &builder) {
+    return getGpuPropertyAsInt(builder, "max_vector_op_width",
+                               static_cast<DerivedT *>(this)->vectorWidth);
+  }
+
+  int64_t getWorkGroupSize(Builder &builder) {
+    return getGpuPropertyAsInt(builder, "max_work_group_size",
+                               static_cast<DerivedT *>(this)->workGroupSize);
+  }
+};
+
+// This class is a placeholder for the rewriter-related boilerplate code.
+struct OpRewriter final : IRRewriter {
+  Location loc;
+
+  explicit OpRewriter(func::FuncOp &func)
+      : IRRewriter(func.getContext()), loc(func.getLoc()) {}
+
+  template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
+    return RewriterBase::create<OpTy>(loc, std::forward<Args>(args)...);
+  }
+
+  arith::ConstantIndexOp createConstant(int64_t v) {
+    return create<arith::ConstantIndexOp>(v);
+  }
+
+  arith::ConstantFloatOp createConstant(double v) {
+    return create<arith::ConstantFloatOp>(APFloat(v), getF64Type());
+  }
+};
+
+template <typename T> static bool isPow2(T value) {
+  assert(value > 0);
+  return (value & (value - 1)) == 0;
+}
+
+// Round to the largest power of 2 that is <= value.
+template <typename T> static T floorPow2(T value) {
+  auto v = static_cast<std::make_unsigned_t<T>>(value);
+  return static_cast<T>(T(1) << (llvm::bit_width(v) - 1));
+}
+
+// Round to the smallest power of 2 that is >= value.
+template <typename T> static T ceilPow2(T value) {
+  auto v = static_cast<std::make_unsigned_t<T>>(value);
+  return static_cast<T>(llvm::bit_ceil(v));
+}
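+
+// For example: floorPow2(12) == 8, ceilPow2(12) == 16 and
+// floorPow2(16) == ceilPow2(16) == 16.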
+
+// Find a factor of the number that is close to the given value and, if
+// possible, is a power of 2.
+template <typename T> T findFactor(T number, T closeTo) {
+  closeTo = std::max(T(1), std::min(closeTo, number));
+
+  for (T max = number - closeTo + 1, i = 0; i < max; ++i) {
+    T up = closeTo + i;
+    if (auto pow2 = ceilPow2(up); number % pow2 == 0) {
+      return pow2;
+    }
+    if (i < closeTo - 1) {
+      T down = closeTo - i;
+      if (auto pow2 = floorPow2(down); pow2 != 1 && number % pow2 == 0) {
+        return pow2;
+      }
+      if (number % down == 0) {
+        return down;
+      }
+    }
+    if (number % up == 0) {
+      return up;
+    }
+  }
+
+  return closeTo;
+}
+
+template <typename T>
+static void adjustTwoTiles(T totalSize, T *aPtr, T *bPtr,
+                           T minSize = static_cast<T>(1)) {
+  T a = *aPtr;
+  T b = *bPtr;
+  assert(a >= b);
+
+  if (a * b <= totalSize) {
+    return;
+  }
+
+  bool aPow2 = isPow2(a);
+  bool bPow2 = isPow2(b);
+  double ratio = static_cast<double>(a) / static_cast<double>(b);
+  T x = static_cast<T>(std::sqrt(totalSize)) *
+        static_cast<T>(std::sqrt(ratio));
+  T y;
+
+  if (aPow2) {
+    x = std::min(ceilPow2(x), std::min(a, floorPow2(totalSize)));
+  } else {
+    x = std::min(findFactor(a, x), std::min(a, totalSize));
+  }
+  x = std::max(x, minSize);
+  if (bPow2) {
+    y = std::min(floorPow2(totalSize / x), b);
+  } else {
+    y = std::min(findFactor(b, totalSize / x), b);
+  }
+  if (y < minSize && a >= minSize && b >= minSize) {
+    if (auto newX = ceilPow2(totalSize / minSize); newX >= minSize) {
+      x = std::min(newX, a);
+      y = minSize;
+    }
+  }
+
+  // Adjust x and y to get the closest ratio
+  auto distance =
+      std::abs(ratio - static_cast<double>(x) / static_cast<double>(y));
+  auto ax = aPow2 ? x * 2 : findFactor(a, x * 2);
+  auto ay = std::max(bPow2 ? y / 2 : findFactor(b, y / 2), minSize);
+
+  if (ax * ay <= totalSize &&
+      std::abs(ratio - static_cast<double>(ax) / static_cast<double>(ay)) <
+          distance) {
+    x = ax;
+    y = ay;
+  } else {
+    ax = std::max(aPow2 ? x / 2 : findFactor(a, x / 2), minSize);
+    ay = bPow2 ? y * 2 : findFactor(b, y * 2);
+    if (ax * ay <= totalSize &&
+        std::abs(ratio - static_cast<double>(ax) / static_cast<double>(ay)) <
+            distance) {
+      x = ax;
+      y = ay;
+    }
+  }
+
+  *aPtr = x;
+  *bPtr = y;
+}
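+
+// For example, adjustTwoTiles(8192, &a, &b) with a = b = 1024 first picks
+// x = ceilPow2(sqrt(8192)) = 128 and y = floorPow2(8192 / 128) = 64, then
+// refines the pair to (64, 128), which is closer to the original ratio of 1.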
+
+// Adjust tile sizes that meet the following conditions:
+// 1. The product of all tiles is as close to totalSize as possible.
+// 2. The new sizes are proportional to the initial sizes.
+// 3. If the initial size is a power of 2, then the resulting size is a power
+//    of 2 too. Otherwise, the resulting size is a factor of the initial size
+//    and, if possible, is a power of 2.
+template <typename T>
+static void adjustTiles(T totalSize, T *begin, T *end,
+                        T minSize = static_cast<T>(1), bool isSorted = false) {
+  assert((minSize & (minSize - 1)) == 0 && "minSize must be a power of 2");
+  auto count = end - begin;
+  if (count == 0) {
+    return;
+  }
+
+  if (count == 1) {
+    if (T a = *begin; isPow2(a)) {
+      *begin = std::min(std::max(ceilPow2(a), minSize), floorPow2(totalSize));
+    } else {
+      *begin = std::max(findFactor(a, totalSize), minSize);
+    }
+    return;
+  }
+
+  if (count > 2) {
+    SmallVector<T> sorted;
+    SmallVector<unsigned> indices;
+    T *head;
+    T *tail;
+
+    if (isSorted) {
+      head = begin;
+      tail = end;
+    } else {
+      SmallVector<std::pair<T, unsigned>> pairs;
+      pairs.reserve(count);
+      for (unsigned i = 0; i < count; ++i) {
+        pairs.emplace_back(*(begin + i), i);
+      }
+      llvm::sort(pairs);
+      sorted.reserve(count);
+      indices.reserve(count);
+      for (auto &p : pairs) {
+        sorted.push_back(p.first);
+        indices.push_back(p.second);
+      }
+      head = sorted.data();
+      tail = head + count;
+    }
+
+    // Split the array in two: the first part consists of 2 elements, the
+    // head element and the product of all the remaining elements; the second
+    // part is the remaining elements themselves.
+    T first[] = {*head, std::accumulate(head + 2, tail, *(head + 1),
+                                        std::multiplies<>())};
+    adjustTiles(totalSize, first, first + 2, minSize, true);
+    adjustTiles(totalSize / *first, head + 1, tail, minSize, true);
+    *head = *first;
+
+    if (!isSorted) {
+      for (unsigned i = 0; i < count; ++i) {
+        *(begin + indices[i]) = sorted[i];
+      }
+    }
+  } else if (*begin >= *(end - 1)) {
+    adjustTwoTiles(totalSize, begin, end - 1, minSize);
+  } else {
+    adjustTwoTiles(totalSize, end - 1, begin, minSize);
+  }
+}
+
+template <typename T>
+static void adjustTiles(T totalSize, SmallVector<T> &tiles,
+                        T minSize = static_cast<T>(1)) {
+  adjustTiles(totalSize, tiles.begin(), tiles.end(), minSize);
+}
+
+// Check recursively if the specified operation has an operand that
+// depends on a result of a previous operation, matching the predicate.
+template <unsigned MaxDepth = std::numeric_limits<unsigned>::max()>
+bool isOperandDependsOnOp(bool (*predicate)(Operation *), Operation *operation,
+                          unsigned depth = 0) {
+  for (auto operand : operation->getOperands()) {
+    if (auto op = operand.getDefiningOp();
+        op &&
+        (predicate(op) || (depth < MaxDepth &&
+                           isOperandDependsOnOp(predicate, op, depth + 1)))) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Check recursively if there are any operations, matching the predicate, that
+// depend on the result of the specified operation.
+template <unsigned MaxDepth = std::numeric_limits<unsigned>::max()>
+bool isOpDependsOnResult(bool (*predicate)(Operation *), Operation *operation,
+                         unsigned depth = 0) {
+  for (auto res : operation->getResults()) {
+    for (auto u : res.getUsers()) {
+      if (predicate(u) ||
+          (depth < MaxDepth && isOpDependsOnResult(predicate, u, depth + 1))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+} // namespace mlir::gc
+#endif
diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp
index b3b1036c3..51067f06d 100644
--- a/lib/gc/Transforms/GPU/Pipeline.cpp
+++ b/lib/gc/Transforms/GPU/Pipeline.cpp
@@ -35,7 +35,7 @@ void populateGPUPipeline(OpPassManager &pm,
     pm.addNestedPass<func::FuncOp>(createAddContextArg());
   }
 
-  pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
+  pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
 
   pm.addPass(bufferization::createEmptyTensorEliminationPass());
   pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
diff --git a/test/mlir/unittests/ExecutionEngine/GPU/GpuOclRuntimeTest.cpp b/test/mlir/unittests/ExecutionEngine/GPU/GpuOclRuntimeTest.cpp
index bf5a4092c..305b0aad0 100644
--- a/test/mlir/unittests/ExecutionEngine/GPU/GpuOclRuntimeTest.cpp
+++ b/test/mlir/unittests/ExecutionEngine/GPU/GpuOclRuntimeTest.cpp
@@ -61,17 +61,17 @@ module @test {
 )mlir";
 
 constexpr char matmulAddStatic[] = R"mlir(
-module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>>} {
-  func.func @entry(%arg0: memref<64x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<64x128xf16>) {
-    %0 = bufferization.to_tensor %arg0 restrict : memref<64x128xf16>
-    %1 = bufferization.to_tensor %arg1 restrict : memref<128x128xf16>
-    %2 = tensor.empty() : tensor<64x128xf16>
+module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"GPU" : #dlti.target_device_spec<#dlti.dl_entry<"max_work_group_size", 16 : i64>>>} {
+  func.func @entry(%arg0: memref<128x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<128x256xf16>) {
+    %0 = bufferization.to_tensor %arg0 restrict : memref<128x256xf16>
+    %1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
+    %2 = tensor.empty() : tensor<128x256xf16>
     %cst = arith.constant 0.000000e+00 : f16
-    %3 = linalg.fill ins(%cst : f16) outs(%2 : tensor<64x128xf16>) -> tensor<64x128xf16>
-    %4 = linalg.matmul_transpose_b ins(%0, %1 : tensor<64x128xf16>, tensor<128x128xf16>) outs(%3 : tensor<64x128xf16>) -> tensor<64x128xf16>
-    %5 = tensor.empty() : tensor<64x128xf16>
-    %6 = linalg.add ins(%4, %0 : tensor<64x128xf16>, tensor<64x128xf16>) outs(%5 : tensor<64x128xf16>) -> tensor<64x128xf16>
-    bufferization.materialize_in_destination %6 in restrict writable %arg2 : (tensor<64x128xf16>, memref<64x128xf16>) -> ()
+    %3 = linalg.fill ins(%cst : f16) outs(%2 : tensor<128x256xf16>) -> tensor<128x256xf16>
+    %4 = linalg.matmul ins(%0, %1 : tensor<128x256xf16>, tensor<256x256xf16>) outs(%3 : tensor<128x256xf16>) -> tensor<128x256xf16>
+    %5 = tensor.empty() : tensor<128x256xf16>
+    %6 = linalg.add ins(%4, %0 : tensor<128x256xf16>, tensor<128x256xf16>) outs(%5 : tensor<128x256xf16>) -> tensor<128x256xf16>
+    bufferization.materialize_in_destination %6 in restrict writable %arg2 : (tensor<128x256xf16>, memref<128x256xf16>) -> ()
     return
   }
 }
@@ -167,7 +167,7 @@ template <unsigned M, unsigned N> struct TestMatmulAdd : TestBase {
     gcGetOrReport(ctx.finish());
     for (unsigned i = 0; i < size1; i++) {
       // std::cout << buf2[i] << " ";
-      assert(buf2[i] == 20496);
+      assert(buf2[i] == 21512);
     }
     // std::cout << "\n";
   }
@@ -220,7 +220,7 @@ TEST(GpuOclRuntime, TestAddDynamic) {
 }
 
 TEST(GpuOclRuntime, TestMatmulAddStatic) {
-  struct Test : TestMatmulAdd<64, 128> {
+  struct Test : TestMatmulAdd<128, 256> {
     void exec(std::shared_ptr<const OclModule> &mod) override {
       assert(mod->isStatic);
       StaticExecutor<3> exec(mod);
diff --git a/test/mlir/unittests/Transforms/CMakeLists.txt b/test/mlir/unittests/Transforms/CMakeLists.txt
index 271c398ee..87fb27060 100644
--- a/test/mlir/unittests/Transforms/CMakeLists.txt
+++ b/test/mlir/unittests/Transforms/CMakeLists.txt
@@ -4,3 +4,7 @@ add_mlir_unittest(GCTransformsTests
 target_link_libraries(GCTransformsTests
   PRIVATE
   GcPasses)
+
+if(GC_ENABLE_IMEX)
+  add_subdirectory(GPU)
+endif()
\ No newline at end of file
diff --git a/test/mlir/unittests/Transforms/GPU/CMakeLists.txt b/test/mlir/unittests/Transforms/GPU/CMakeLists.txt
new file mode 100644
index 000000000..326948bb3
--- /dev/null
+++ b/test/mlir/unittests/Transforms/GPU/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_mlir_unittest(GpuTransformsTests
+  GpuUtilsTest.cpp
+)
+target_link_libraries(GpuTransformsTests PRIVATE GcGpuPasses)
+target_include_directories(GpuTransformsTests PRIVATE ${PROJECT_SOURCE_DIR}/lib)
diff --git a/test/mlir/unittests/Transforms/GPU/GpuUtilsTest.cpp b/test/mlir/unittests/Transforms/GPU/GpuUtilsTest.cpp
new file mode 100644
index 000000000..0f8868696
--- /dev/null
+++ b/test/mlir/unittests/Transforms/GPU/GpuUtilsTest.cpp
@@ -0,0 +1,94 @@
+//===- GpuUtilsTest.cpp - Tests for GpuUtils ------------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <iostream>
+
+#include "gc/Transforms/GPU/GpuUtils.h"
+
+#include "gtest/gtest.h"
+
+TEST(testAdjustTiles, GpuUtilsTest) {
+  bool print = false;
+  auto testAdjust = [print](int64_t totalSize, SmallVector<int64_t> &tiles,
+                            const SmallVector<int64_t> &expected) {
+    if (print) {
+      std::cout << totalSize << ": [";
+      for (unsigned i = 0; i < tiles.size(); i++) {
+        std::cout << tiles[i] << (i + 1 < tiles.size() ? ", " : "");
+      }
+      std::cout << "] -> [";
+    }
", " : ""); + } + std::cout << "] -> ["; + } + + gc::adjustTiles(totalSize, tiles); + + if (print) { + for (unsigned i = 0; i < tiles.size(); i++) { + std::cout << tiles[i] << (i + 1 < tiles.size() ? ", " : ""); + } + std::cout << "]" << std::endl; + } + + EXPECT_EQ(tiles, expected); + }; + auto test = [testAdjust](int64_t totalSize, SmallVector tiles, + SmallVector expected) { + if (tiles.size() != 2 || tiles[0] == tiles[1]) { + testAdjust(totalSize, tiles, expected); + return; + } + SmallVector reversed(tiles.rbegin(), tiles.rend()); + testAdjust(totalSize, tiles, expected); + std::reverse(expected.begin(), expected.end()); + testAdjust(totalSize, reversed, expected); + }; + + test(8, {1, 1}, {1, 1}); + test(8, {1, 2}, {1, 2}); + test(8, {2, 2}, {2, 2}); + test(8, {1, 4}, {1, 4}); + test(8, {1, 8}, {1, 8}); + test(8, {2, 8}, {2, 4}); + test(8, {1, 32}, {1, 8}); + test(8, {2, 32}, {1, 8}); + test(8, {4, 32}, {1, 8}); + test(8, {8, 32}, {2, 4}); + test(8, {16, 32}, {2, 4}); + test(8, {32, 32}, {2, 4}); + test(8, {64, 32}, {4, 2}); + test(8, {128, 32}, {4, 2}); + + test(8192, {1024, 1024}, {64, 128}); + test(8192, {32, 32}, {32, 32}); + test(8192, {16, 64}, {16, 64}); + test(8192, {8, 128}, {8, 128}); + test(8192, {4, 256}, {4, 256}); + test(8192, {2, 512}, {2, 512}); + test(8192, {1, 1024}, {1, 1024}); + test(8192, {512, 2}, {512, 2}); + test(8192, {256, 4}, {256, 4}); + test(8192, {128, 8}, {128, 8}); + test(8192, {64, 16}, {64, 16}); + test(8192, {32, 32}, {32, 32}); + + test(16384, {1, 1, 1}, {1, 1, 1}); + test(16384, {1, 2, 4}, {1, 2, 4}); + test(16384, {2, 4, 8}, {2, 4, 8}); + test(16384, {4, 8, 16}, {4, 8, 16}); + test(16384, {8, 16, 32}, {8, 16, 32}); + test(16384, {16, 32, 64}, {16, 32, 32}); + test(16384, {32, 64, 128}, {8, 32, 64}); + test(16384, {64, 128, 256}, {8, 32, 64}); + test(16384, {128, 256, 512}, {4, 64, 64}); + + test(16384, {7, 17, 111}, {7, 17, 111}); + test(16384, {7, 117, 111}, {7, 39, 37}); + test(16384, {6, 256, 512}, {1, 128, 128}); + test(16384, {60, 128, 512}, {4, 32, 128}); + test(16384, {119, 256, 512}, {7, 32, 64}); + test(16384, {109, 256, 512}, {109, 8, 16}); +}