From 6a4cd4af5e829ad06af85c6927796701f2142937 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Tue, 12 Nov 2024 10:54:47 +0000 Subject: [PATCH 1/7] Lower SLM to XeGPU - draft Signed-off-by: dchigarev --- CMakeLists.txt | 3 + include/gc/Utils/Log.h | 1 + lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 613 +++++++++++++++--- lib/gc/Transforms/GPU/Pipeline.cpp | 7 +- .../Transforms/GPU/linalg-to-xegpu-slm.mlir | 40 ++ .../XeGPU/f16_matmul_64x128_slm.mlir | 95 +++ 6 files changed, 668 insertions(+), 91 deletions(-) create mode 100644 test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir create mode 100644 test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir diff --git a/CMakeLists.txt b/CMakeLists.txt index 66af39632..24ffcd614 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,4 +140,7 @@ install(FILES ${PROJECT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake DESTINATION "lib/cmake/${PROJECT_NAME}" ) + +message("C++ Standard: ${CMAKE_CXX_STANDARD}") + ################################################################################ diff --git a/include/gc/Utils/Log.h b/include/gc/Utils/Log.h index 8755cf933..6bb765a34 100644 --- a/include/gc/Utils/Log.h +++ b/include/gc/Utils/Log.h @@ -68,6 +68,7 @@ static void debug(const char *fileName, int lineNum, Args... args) { #define gcLogD(...) mlir::gc::log::debug(__FILE__, __LINE__, __VA_ARGS__) #define gcLogE(...) \ mlir::gc::log::log(__FILE__, __LINE__, std::cerr, "ERROR", __VA_ARGS__) +#define gcRunD(...) if (mlir::gc::log::isDebugEnabled(__FILE__)) {__VA_ARGS__;} #endif } // namespace mlir::gc::log diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index 0135d8fe3..584de6489 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -12,6 +12,7 @@ #include "gc/Transforms/Utils/StructuredOpMatcher.h" #include "gc/Transforms/Utils/ValueUtils.h" +#include "gc/Utils/Log.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -48,6 +49,83 @@ namespace gc { namespace { +// TODO: move these to utils +template +static Value createTypedVector(PatternRewriter &rewriter, Location loc, + ArrayRef values, Type elementType) { + mlir::VectorType vectorType = + mlir::VectorType::get({static_cast(values.size())}, elementType); + mlir::DenseElementsAttr denseAttr = + mlir::DenseElementsAttr::get(vectorType, values); + auto vector = + rewriter.create(loc, vectorType, denseAttr) + .getResult(); + return vector; +} + +static Value createIndexVector(PatternRewriter &rewriter, Location loc, + ArrayRef values) { + return createTypedVector(rewriter, loc, values, rewriter.getIndexType()); +} + +static Value createIndexConstant(PatternRewriter &rewriter, Location loc, + int64_t value) { + return rewriter.create(loc, value); +} + +static Value flattenMemref(PatternRewriter &rewriter, Location loc, + Value srcMemref) { + auto srcType = cast(srcMemref.getType()); + + assert(srcType && "Expected a memref type"); + assert(srcType.getRank() == 2 && "Expected a 2D memref"); + + int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1]; + + Value offset = rewriter.create(loc, 0); + Value size = rewriter.create(loc, flatSize); + Value stride = rewriter.create(loc, 1); + + // Use memref.reinterpret_cast to flatten the memref + auto flatMemRefType = MemRefType::get({flatSize}, srcType.getElementType(), + nullptr, srcType.getMemorySpace()); + auto flatMemref = + rewriter + .create(loc, flatMemRefType, srcMemref, + 
offset, size, stride) + .getResult(); + return flatMemref; +} + +static bool hasSharedMemSpace(mlir::Value memref) { + auto type = mlir::dyn_cast(memref.getType()); + if (!type) + return false; + + auto memSpace = type.getMemorySpace(); + if (!memSpace) + return false; + + if (auto gpuAttr = mlir::dyn_cast(memSpace)) + return gpuAttr.getValue() == mlir::gpu::AddressSpace::Private; + + if (auto intAttr = mlir::dyn_cast(memSpace)) + return intAttr.getValue() == + static_cast(mlir::gpu::AddressSpace::Private); + + return false; +} + +static Value createFullMask(PatternRewriter &rewriter, Location loc, + int64_t size) { + auto maskVal = createIndexConstant(rewriter, loc, 32); + mlir::VectorType maskVectorType = + mlir::VectorType::get({size}, rewriter.getI1Type()); + auto res = rewriter.create( + loc, maskVectorType, SmallVector({maskVal})); + return res.getResult(); +} + // Represents VNNI configuration for an operand. struct VnniConfig { int vnniFactor; @@ -83,10 +161,31 @@ struct TilesArray { SmallVector> tileMatrix; }; -static xegpu::TensorDescType getTensorDescType(llvm::ArrayRef shape, - mlir::Type elementType) { - return xegpu::TensorDescType::get(shape, elementType, /*array_length*/ 1, - /*boundary_check*/ true); +static xegpu::TensorDescType +getTensorDescType(llvm::ArrayRef shape, mlir::Type elementType, + std::optional sgMap = std::nullopt) { + if (!sgMap) { + // Assuming default tensor descriptor type (blocked & in global memory). + return xegpu::TensorDescType::get(shape, elementType, /*array_length=*/1, + /*boundary_check=*/true); + } + + auto descriptor = sgMap.value(); + if (auto scatterMap = dyn_cast(descriptor)) { + auto memSpace = scatterMap.getMemorySpace().getValue(); + int64_t chunkSize = scatterMap.getChunkSize().getInt(); + return xegpu::TensorDescType::get(shape, elementType, chunkSize, memSpace); + } + + if (auto blockMap = dyn_cast(descriptor)) { + auto memorySpace = blockMap.getMemorySpace().getValue(); + int64_t arrayLength = blockMap.getArrayLength().getInt(); + bool boundaryCheck = blockMap.getBoundaryCheck().getValue(); + return xegpu::TensorDescType::get(shape, elementType, arrayLength, + boundaryCheck, memorySpace); + } + + assert(false && "Unknown tensor descriptor type"); } // Return DPAS tile sizes if the gemm-like operation fits DPAS hardware. @@ -634,11 +733,10 @@ static SmallVector updateTilesOffsets(PatternRewriter &rewriter, // // The descriptor sub-tiles are ordered in row-major fashion with respect to the // whole load tile. -static SmallVector -createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src, - ArrayRef loadShape, - ArrayRef loadOffsets, ArrayRef descTile, - int arrayLength = 1, bool transpose = false) { +static SmallVector createNdDescriptorTiles( + PatternRewriter &rewriter, Location loc, Value src, + ArrayRef loadShape, ArrayRef loadOffsets, + ArrayRef descTile, int arrayLength = 1, bool transpose = false) { assert(arrayLength == 1 && "Array descriptors are not supported"); auto type = cast(src.getType()); @@ -686,30 +784,175 @@ createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src, return tiles; } -// Create coarse sub-tiles to be loaded by the current subgroup. +// Split a source into a series of 1D descriptor tiles. Each descriptor tile +// loads exactly 32 elements. // -// The shape to be loaded is split into the largest 2D loads supported -// by the hardware. +// The descriptors collectively load blocks of the 'loadShape2D' shape +// with the chunk sizes being 'tileSize2D'. 
// -// The load subgroup tiles are ordered in row-major fashion with respect to the -// source shape. -static SmallVector createCoarseDscTiles(PatternRewriter &rewriter, - Location loc, Value src, - ArrayRef sgTile, - bool isVnni, - bool transpose = false) { - assert(sgTile.size() <= 2 && +// The descriptor sub-tiles are ordered in row-major fashion with respect to the +// whole load tile. +static SmallVector createScatterDescriptorTiles( + PatternRewriter &rewriter, Location loc, Value flatMemref, + ArrayRef loadShape2D, ArrayRef tileSize2D, + ArrayRef memrefStrides, Value blockOffset) { + int64_t maxLoadSize = 32; + + assert(memrefStrides.size() == 2 && "Strides must be 2D"); + assert(memrefStrides[1] == 1 && "Only row-major strides are supported"); + assert(loadShape2D.size() == 2 && "Load shape must be 2D"); + assert(loadShape2D[0] * loadShape2D[1] % maxLoadSize == 0 && + "Load shape must be divisible by max load size"); + assert(tileSize2D.size() == 2 && "Descriptor tile must be 2D"); + assert(maxLoadSize % tileSize2D[1] == 0 && + "Descriptor tile must be divisible by max load size"); + + int64_t numLoadsPerTile = tileSize2D[0] * tileSize2D[1] / maxLoadSize; + // How much rows of a tile fit into a single descriptor load + int64_t rowsPerLoad = maxLoadSize / tileSize2D[1]; + int64_t numColTiles = loadShape2D[1] / tileSize2D[1]; + + auto memrefType = dyn_cast(flatMemref.getType()); + + SmallVector> offsetShiftValues; + for (int colTile = 0; colTile < numColTiles; colTile++) { + offsetShiftValues.push_back(SmallVector()); + for (int i = 0; i < rowsPerLoad; i++) { + int64_t offset = i * memrefStrides[0]; + for (int j = 0; j < maxLoadSize / rowsPerLoad; j++) + offsetShiftValues[colTile].push_back(offset + j + + colTile * tileSize2D[1]); + } + } + + int64_t skipPerLoad = memrefStrides[0] * rowsPerLoad; + auto offsetPerLoad = + createIndexVector(rewriter, loc, SmallVector(32, skipPerLoad)); + + auto offsetVecType = VectorType::get({maxLoadSize}, rewriter.getIndexType()); + auto descType = getTensorDescType( + {maxLoadSize}, memrefType.getElementType(), + xegpu::ScatterTensorDescAttr::get( + rewriter.getContext(), xegpu::MemorySpace::SLM, /*chunkSize=*/1)); + + // Could have used 'vector.splat' here instead but it is not supported + // by 'imex::ConvertGPUXToSPIRVPass'. 
+ SmallVector blockOffsetValues(32, blockOffset); + auto blockOffsetV = rewriter.create( + loc, offsetVecType, blockOffsetValues); + + SmallVector tiles; + for (int i = 0; i < numColTiles; i++) { + auto offsetsShift = createIndexVector(rewriter, loc, offsetShiftValues[i]); + auto offsets0 = + rewriter.create(loc, blockOffsetV, offsetsShift); + + auto desc = + rewriter + .create(loc, descType, flatMemref, offsets0) + .getResult(); + tiles.push_back(desc); + for (int j = maxLoadSize; j < loadShape2D[0] * loadShape2D[1] / numColTiles; + j += maxLoadSize) { + auto newTile = rewriter + .create( + loc, descType, tiles.back(), offsetPerLoad) + .getResult(); + tiles.push_back(newTile); + } + } + + // Reorder the tiles into a row-major format by transposing the generated + // layout + SmallVector transposedTiles; + int numRowTiles = tiles.size() / numColTiles; + + for (int rowTile = 0; rowTile < numRowTiles; rowTile += numLoadsPerTile) + for (int colTile = 0; colTile < numColTiles; colTile++) + for (int loadOffset = 0; loadOffset < numLoadsPerTile; loadOffset++) { + int newIdx = rowTile + colTile * numRowTiles + loadOffset; + transposedTiles.push_back(tiles[newIdx]); + } + + return transposedTiles; +} + +// Creates descriptors to load from SLM. +// +// The function returns a vector of 1D descriptor tiles that load the specified +// 2D shape from the SLM. +static SmallVector createSLMDescTiles(PatternRewriter &rewriter, + Location loc, Value src, + ArrayRef loadShape, + ArrayRef descTile) { + assert(loadShape.size() <= 2 && "Require at most 2D tile size for eltwise lowering"); - // Ensure that load is 2D. - // TODO: Add support for 1D loads. - SmallVector sgTile2D{sgTile}; - if (sgTile.size() == 1) - sgTile2D.push_back(1); + auto srcType = src.getType().cast(); + assert(srcType.getRank() == 2 && "Expected a 2D memref"); + auto elemByteWidth = srcType.getElementType().getIntOrFloatBitWidth() / 8; + + SmallVector memrefStrides; + Value blockOffset; + + // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the + // GPU kernel. We have to merge the subview offsets into the descriptor + // offset. 
+ if (auto subView = dyn_cast(src.getDefiningOp())) { + auto xIntOffs = subView.getOffsets()[0]; + auto yIntOffs = subView.getOffsets()[1]; + + // compute 'blockOffset' (beginning of the subview block in the original + // flat memref) + auto rowStride = + cast(subView.getOperand(0).getType()).getShape()[1]; + auto rowStrideValue = + rewriter.create(loc, rowStride); + + auto rowBlockOffset = + rewriter.create(loc, xIntOffs, rowStrideValue) + .getResult(); + blockOffset = rewriter.create(loc, rowBlockOffset, yIntOffs) + .getResult(); + + memrefStrides = {rowStride, 1}; + src = subView.getOperand(0); + } else { + // If the source is not a subview, then the blockOffset is 0 + blockOffset = rewriter.create(loc, 0); + memrefStrides = {srcType.getShape()[1], 1}; + } + + // Scatter descriptors only work with 1D memrefs + src = flattenMemref(rewriter, loc, src); - auto type = cast(src.getType()); - auto elemByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + return createScatterDescriptorTiles( + rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape, + /*tileSize2D=*/descTile, /*memrefStrides=*/memrefStrides, + /*blockOffset=*/blockOffset); +} +static SmallVector createDescriptorTiles( + PatternRewriter &rewriter, Location loc, Value src, + ArrayRef loadShape, ArrayRef descTile, + std::optional> loadOffsets = std::nullopt, + int arrayLength = 1, bool transpose = false) { + + if (hasSharedMemSpace(src)) { + assert(!transpose && "Transpose is not supported for shared memory"); + assert(arrayLength == 1 && + "Array descriptors are not supported for shared memory"); + assert(!loadOffsets && "Load offsets are not supported for shared memory"); + return createSLMDescTiles(rewriter, loc, src, loadShape, descTile); + } + return createNdDescriptorTiles( + rewriter, loc, src, loadShape, + loadOffsets.value_or(SmallVector{0, 0}), descTile, arrayLength, + transpose); +} + +SmallVector determine2DTileSize(ArrayRef totalShape, + bool isVnni, int64_t elemByteWidth) { // TODO: Fetch actual list of supported load configs. int64_t maxHeight = 32; int64_t maxWidth = 64 / elemByteWidth; @@ -717,22 +960,34 @@ static SmallVector createCoarseDscTiles(PatternRewriter &rewriter, // TODO: Make the VNNI-factor flexible. if (isVnni) maxWidth /= 2; - int64_t maxArrayLength = 4; - int64_t sgLoadRows = std::min(sgTile2D[0], maxHeight); - int64_t sgLoadCols = std::min(sgTile2D[1], maxWidth); - int64_t arrayLength = std::min(maxWidth / sgLoadCols, maxArrayLength); - // In case of partial fit, load only single tile. - // NOLINTBEGIN - if (maxWidth % sgLoadCols != 0 || arrayLength != 4 || arrayLength != 2) - arrayLength = 1; - // TODO: Add variable array_length support. - arrayLength = 1; - // NOLINTEND + int64_t sgLoadRows = std::min(totalShape[0], maxHeight); + int64_t sgLoadCols = std::min(totalShape[1], maxWidth); - return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0}, - {sgLoadRows, sgLoadCols}, arrayLength, - transpose); + return SmallVector{sgLoadRows, sgLoadCols}; +} + +// Create coarse sub-tiles to be loaded by the current subgroup. +// +// The shape to be loaded is split into the largest 2D loads supported +// by the hardware. +// +// The load subgroup tiles are ordered in row-major fashion with respect to the +// source shape. 
+static std::tuple, SmallVector> +createCoarseDscTiles(PatternRewriter &rewriter, Location loc, Value src, + ArrayRef sgTile, bool isVnni, + bool transpose = false) { + auto type = cast(src.getType()); + auto elemByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + + auto tileSize = + determine2DTileSize(sgTile, isVnni, /*elementByteWidth=*/elemByteWidth); + auto descriptors = + createDescriptorTiles(rewriter, loc, src, sgTile, tileSize, std::nullopt, + /*array_length=*/1, transpose); + + return std::make_tuple(descriptors, tileSize); } // Return vector type with specified VNNI shape. @@ -749,13 +1004,20 @@ static VectorType getVnniVector(ArrayRef shape, Type elementType, static SmallVector loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, xegpu::CachePolicyAttr hint, + std::optional> tileShape = std::nullopt, std::optional vnniConf = std::nullopt, DenseI64ArrayAttr transpose = nullptr, IntegerAttr transpose_bit = nullptr) { // Assume all tiles have the same shape. auto tileType = cast(loadTiles[0].getType()); + auto tileShapeValue = tileShape.value_or(tileType.getShape()); assert(llvm::all_of(loadTiles, - [&](Value tile) { return tile.getType() == tileType; }) && + [&](Value tile) { + auto xeTile = + cast(tile.getType()); + return xeTile && xeTile == tileType && + tileShapeValue.equals(xeTile.getShape()); + }) && "All load tiles must have the same type."); VectorType vecLoadType = @@ -783,6 +1045,189 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, return loadVec; } +// Load from scatter 1D descriptors and return a vector of 2D tiles +// with the shape of 'tileShape'. +static SmallVector +loadScatterDescTiles(PatternRewriter &rewriter, Location loc, + ValueRange loadTiles, xegpu::CachePolicyAttr hint, + ArrayRef tileShape, + std::optional vnniConf = std::nullopt, + DenseI64ArrayAttr transpose = nullptr, + IntegerAttr transpose_bit = nullptr) { + int64_t elementsPerLoad = 32; + + // Assume all tiles have the same shape. 
+ auto tileType = cast(loadTiles[0].getType()); + assert(llvm::all_of(loadTiles, + [&](Value tile) { return tile.getType() == tileType; }) && + "All load tiles must have the same type."); + assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D"); + assert(tileType.getShape()[0] == elementsPerLoad && + "Scatter tiles must have 32 elements"); + assert(!vnniConf && "VNNI not supported for scatter loads"); + assert(!transpose && "Transpose is not supported for scatter loads"); + assert(!transpose_bit && "Transpose is not supported for scatter loads"); + + int64_t totalLoadElems = tileType.getShape()[0] * loadTiles.size(); + assert(totalLoadElems % elementsPerLoad == 0 && + "Total load size must be multiple of 32"); + assert(tileShape[0] * tileShape[1] % elementsPerLoad == 0 && + "Tile shape must be multiple of 32"); + + int64_t loadsPerTile = tileShape[0] * tileShape[1] / elementsPerLoad; + int64_t totalNumLoads = totalLoadElems / elementsPerLoad; + auto mask = createFullMask(rewriter, loc, elementsPerLoad); + + SmallVector result; + auto elementType = tileType.getElementType(); + SmallVector accumValues( + loadsPerTile * elementsPerLoad, + dyn_cast(rewriter.getZeroAttr(elementType))); + + VectorType accumVectorType = + VectorType::get({loadsPerTile, elementsPerLoad}, elementType); + VectorType loadVectorType = VectorType::get({elementsPerLoad}, elementType); + + for (int64_t tileIdx = 0; tileIdx < totalNumLoads; tileIdx += loadsPerTile) { + // Accumulator vector for the current tile (its number of elements equals to + // tileShape) HACK: we first create a flat vector of zeros and then cast it + // to the 2D shape. Otherwise 'imex::ConvertGPUXToSPIRVPass' fails. + auto accumVector = + createTypedVector(rewriter, loc, accumValues, elementType); + accumVector = + rewriter.create(loc, accumVectorType, accumVector); + + // Load from descriptors to the accumulator vector. + for (int64_t loadIdx = 0; loadIdx < loadsPerTile; loadIdx++) { + auto loadOp = rewriter.create( + loc, loadVectorType, loadTiles[tileIdx + loadIdx], /*mask=*/mask, + /*transpose=*/nullptr, + // Do we need those for SLM? + /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); + + accumVector = rewriter.create( + loc, loadOp.getResult(), accumVector, SmallVector{loadIdx}); + } + + if (tileShape[1] == elementsPerLoad) { + // No need to reshape the accumulator vector. + result.push_back(accumVector); + continue; + } + + // Cast the accumulator vector to the 'tileShape' + auto flatTile = rewriter.create( + loc, VectorType::get({tileShape[0] * tileShape[1]}, elementType), + accumVector); + auto loadedTile = rewriter.create( + loc, VectorType::get({tileShape[0], tileShape[1]}, elementType), + flatTile); + result.push_back(loadedTile); + } + +#ifndef NDEBUG + // verify correctness + int64_t elemsLoaded = 0; + for (auto v : result) { + auto shape = cast(v.getType()).getShape(); + elemsLoaded += shape[0] * shape[1]; + } + assert(elemsLoaded == totalLoadElems && + "Loaded number of elements must match the total number of elements"); +#endif + + return result; +} + +// Load from descriptors and return a vector of 2D tiles with the shape of +// 'tileShape'. 
+static SmallVector +loadDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, + xegpu::CachePolicyAttr hint, + std::optional> tileShape = std::nullopt, + std::optional vnniConf = std::nullopt, + DenseI64ArrayAttr transpose = nullptr, + IntegerAttr transpose_bit = nullptr) { + auto tile = dyn_cast(loadTiles[0].getType()); + if (tile.getMemorySpace() == xegpu::MemorySpace::SLM) { + assert(tileShape.has_value() && + "tileShape must be provided for scatter loads"); + return loadScatterDescTiles(rewriter, loc, loadTiles, hint, + tileShape.value(), vnniConf, transpose, + transpose_bit); + } + return loadNdDescTiles(rewriter, loc, loadTiles, hint, tileShape, vnniConf, + transpose, transpose_bit); +} + +static void storeNdDescTiles(PatternRewriter &rewriter, Location loc, + SmallVector &results, ValueRange storeTiles, + xegpu::CachePolicyAttr hint) { + for (size_t i = 0; i < storeTiles.size(); i++) { + rewriter.create(loc, results[i], storeTiles[i], + /*l1_hint=*/hint, + /*l2_hint=*/hint, + /*l3_hint=*/hint); + } +} + +static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc, + SmallVector &results, + ValueRange storeTiles, + xegpu::CachePolicyAttr hint) { + int64_t elementsPerStore = 32; + + auto tileType = cast(storeTiles[0].getType()); + assert(llvm::all_of(storeTiles, + [&](Value tile) { return tile.getType() == tileType; }) && + "All load tiles must have the same type."); + assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D"); + assert(tileType.getShape()[0] == elementsPerStore && + "Scatter tiles must have 32 elements"); + + auto mask = createFullMask(rewriter, loc, elementsPerStore); + int64_t descIdx = 0; + + for (auto vec : results) { + auto vecType = dyn_cast(vec.getType()); + auto vecShape = vecType.getShape(); + assert(vecShape.size() == 2 && "Expected 2D vector"); + assert(vecShape[0] * vecShape[1] % elementsPerStore == 0 && + "Vector shape must be divisible by load size"); + + // Flatten the vector to 1D + auto flatVec = rewriter.create( + loc, + VectorType::get({vecShape[0] * vecShape[1]}, vecType.getElementType()), + vec); + // Extract slices of 32 size from 'flatVec' and store them + for (int64_t loadChunkIdx = 0; loadChunkIdx < vecShape[0] * vecShape[1]; + loadChunkIdx += elementsPerStore) { + auto toStore = rewriter.create( + loc, flatVec, /*offsets=*/SmallVector({loadChunkIdx}), + /*sizes=*/SmallVector({elementsPerStore}), + /*strides=*/SmallVector({1})); + rewriter.create(loc, toStore, storeTiles[descIdx], + /*mask=*/mask, + /*transpose=*/nullptr, + /*l1_hint=*/hint, + /*l2_hint=*/hint, + /*l3_hint=*/hint); + descIdx++; + } + } +} + +static void storeDescTiles(PatternRewriter &rewriter, Location loc, + SmallVector &results, ValueRange storeTiles, + xegpu::CachePolicyAttr hint) { + auto tile = dyn_cast(storeTiles[0].getType()); + if (tile.getMemorySpace() == xegpu::MemorySpace::SLM) { + return storeScatterDescTiles(rewriter, loc, results, storeTiles, hint); + } + return storeNdDescTiles(rewriter, loc, results, storeTiles, hint); +} + // Splits loaded tiles of a larger 2D tile into individual subtiles and places // them in their corresponding positions with respect to the original large // tile. @@ -986,19 +1431,20 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, int dimK = typeA.getShape().back(); // Create C sub-tiles. 
- auto dpasTypeC = - getTensorDescType({dpasTileM, dpasTileN}, typeC.getElementType()); - SmallVector tilesC = createDescriptorTiles( - rewriter, loc, matC, typeC.getShape(), {0, 0}, dpasTypeC.getShape()); + SmallVector dpasShapeC({dpasTileM, dpasTileN}); + + auto tilesC = + createDescriptorTiles(rewriter, loc, matC, typeC.getShape(), dpasShapeC); // Load C sub-tiles. // Fetch the inital values of the output accumulator. SmallVector loadVecC = - loadNdDescTiles(rewriter, loc, tilesC, readCacheHint); + loadDescTiles(rewriter, loc, tilesC, readCacheHint, + /*resultShape=*/dpasShapeC, /*vnniConf=*/std::nullopt, + /*transpose=*/nullptr, /*transpose_bit=*/nullptr); // DPAS only works with F32 accumulators. - auto dpasResType = - VectorType::get(dpasTypeC.getShape(), FloatType::getF32(ctx)); + auto dpasResType = VectorType::get(dpasShapeC, FloatType::getF32(ctx)); // Extend the accumulation values if needed. auto convOutPrecision = !typeC.getElementType().isF32(); @@ -1043,11 +1489,11 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, } // Create A sub-tiles. - SmallVector tilesA = + auto [tilesA, tilesShapeA] = createCoarseDscTiles(rewriter, loc, matA, {dimM, kTile}, /*isVnni=*/true); // Create B sub-tiles. - SmallVector tilesB = + auto [tilesB, tilesShapeB] = createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true, transposeB); @@ -1171,7 +1617,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, // Load A sub-tiles. SmallVector loadVecA = - loadNdDescTiles(rewriter, loc, tilesA, readCacheHint); + loadDescTiles(rewriter, loc, tilesA, readCacheHint, tilesShapeA, + /*vnniConf=*/std::nullopt, /*transpose=*/nullptr, + /*transpose_bit=*/nullptr); auto tileTypeA = cast(tilesA[0].getType()); DenseI64ArrayAttr transpose = nullptr; @@ -1184,8 +1632,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, // Load B sub-tiles. SmallVector loadVecB = - loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB, - transpose, transpose_bit); + loadDescTiles(rewriter, loc, tilesB, readCacheHint, tilesShapeB, + vnniConfB, transpose, transpose_bit); auto tileTypeB = cast(tilesB[0].getType()); // Update offsets of the input tiles. @@ -1280,8 +1728,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, // Truncate the result values if needed. if (convOutPrecision) { - auto truncType = - VectorType::get(dpasTypeC.getShape(), typeC.getElementType()); + auto truncType = VectorType::get(dpasShapeC, typeC.getElementType()); for (size_t i = 0; i < results.size(); i++) { auto truncOp = rewriter.create(loc, truncType, results[i]); @@ -1289,16 +1736,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, } } - // Write back the final C sub-tiles results to the output buffer. 
- SmallVector storeOps; - for (size_t i = 0; i < tilesC.size(); i++) { - auto storeOp = - rewriter.create(loc, results[i], tilesC[i], - /*l1_hint=*/writeCacheHint, - /*l2_hint=*/writeCacheHint, - /*l3_hint=*/writeCacheHint); - storeOps.push_back(storeOp); - } + storeDescTiles(rewriter, loc, results, tilesC, writeCacheHint); rewriter.eraseOp(linalgOp); @@ -1312,21 +1750,28 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp, auto ctx = linalgOp.getContext(); auto output = linalgOp.getDpsInits()[0]; - auto outputShape = cast(output.getType()).getShape(); + auto outputType = cast(output.getType()); + auto outputShape = outputType.getShape(); + auto outputByteWidth = outputType.getElementTypeBitWidth() / 8; + auto tileShape = + determine2DTileSize(outputShape, /*isVnni=*/false, outputByteWidth); // Create descriptors and load values for all inputs. SmallVector> loadedInputs; for (auto input : linalgOp.getDpsInputs()) { - SmallVector inputTiles = createCoarseDscTiles( - rewriter, loc, input, outputShape, /*isVnni=*/false); + SmallVector inputTiles = + createDescriptorTiles(rewriter, loc, input, outputShape, tileShape); + SmallVector loadedVals = - loadNdDescTiles(rewriter, loc, inputTiles, /*hint=*/nullptr); + loadDescTiles(rewriter, loc, inputTiles, /*hint=*/nullptr, tileShape, + /*vnniConf=*/std::nullopt, + /*transpose=*/nullptr, /*transpose_bit=*/nullptr); loadedInputs.push_back(loadedVals); } // Extract SIMD sized sub-tiles from loaded tiles. // TODO: Fetch SIMD sizes from target descriptor. - int maxSizeSIMD = 256; + int64_t maxSizeSIMD = 256; auto loadShape = cast(loadedInputs[0][0].getType()).getShape(); // For sake of n-D loads and store, the vectorized operations are kept in 2D // shape. The loaded tiles might be larger than what SIMD units can handle. @@ -1368,18 +1813,13 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp, // Output descriptors for later stores. SmallVector outputTiles = createDescriptorTiles( - rewriter, loc, output, outputShape, {0, 0}, {subTileRows, subTileCols}); + rewriter, loc, output, outputShape, {subTileRows, subTileCols}); // Store results. auto writeCacheHint = xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK); - for (size_t i = 0; i < outputTiles.size(); i++) { - rewriter.create(loc, results[i], outputTiles[i], - /*l1_hint=*/writeCacheHint, - /*l2_hint=*/writeCacheHint, - /*l3_hint=*/writeCacheHint); - } + storeDescTiles(rewriter, loc, results, outputTiles, writeCacheHint); rewriter.eraseOp(linalgOp); return success(); @@ -1521,13 +1961,14 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp, } // Extract SIMD sized sub-tiles - int maxSizeSIMD = 256; - int64_t subTileCols = outputShape[1]; - int64_t subTileRows = std::min(outputShape[0], maxSizeSIMD / subTileCols); + int64_t maxSizeSIMD = hasSharedMemSpace(output) ? 32 : 256; + int64_t subTileCols = std::min(outputShape[1], maxSizeSIMD); + int64_t subTileRows = + std::min(outputShape[0], std::max(maxSizeSIMD / subTileCols, 1L)); // Output descriptors for later stores. SmallVector outputTiles = createDescriptorTiles( - rewriter, loc, output, outputShape, {0, 0}, {subTileRows, subTileCols}); + rewriter, loc, output, outputShape, {subTileRows, subTileCols}); SmallVector results; for (size_t i = 0; i < outputTiles.size(); i++) { @@ -1548,12 +1989,8 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp, // Store results. 
auto writeCacheHint = xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK); - for (size_t i = 0; i < outputTiles.size(); i++) { - rewriter.create(loc, results[i], outputTiles[i], - /*l1_hint=*/writeCacheHint, - /*l2_hint=*/writeCacheHint, - /*l3_hint=*/writeCacheHint); - } + + storeDescTiles(rewriter, loc, results, outputTiles, writeCacheHint); rewriter.eraseOp(linalgOp); diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp index e17f23140..3b4bf5525 100644 --- a/lib/gc/Transforms/GPU/Pipeline.cpp +++ b/lib/gc/Transforms/GPU/Pipeline.cpp @@ -63,14 +63,16 @@ void populateGPUPipeline(OpPassManager &pm, pm.addPass(createBufferizationToMemRefPass()); pm.addNestedPass(createForallToParallelLoopPass()); + pm.addNestedPass(createGpuMapParallelLoopsPass()); + pm.addNestedPass(createParallelLoopToGpuPass()); + pm.addPass(createCanonicalizerPass()); + pm.addNestedPass(createAllocsToSLM()); pm.addNestedPass(createLinalgToXeGPU( {/*kTile=*/16, /*stages=*/1, /*dpasTiles=*/{8, 16, 16}})); pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addPass(xegpu::createXeGPUFoldAliasOps()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addNestedPass(createGpuMapParallelLoopsPass()); - pm.addNestedPass(createParallelLoopToGpuPass()); imex::InsertGPUAllocsOptions insertGPUAllocsOption{ /*clientAPI*/ "opencl", /*inRegions*/ false, @@ -78,7 +80,6 @@ void populateGPUPipeline(OpPassManager &pm, pm.addNestedPass( imex::createInsertGPUAllocsPass(insertGPUAllocsOption)); pm.addPass(createGpuKernelOutliningPass()); - pm.addPass(createCanonicalizerPass()); pm.addPass(imex::createSetSPIRVCapabilitiesPass()); pm.addNestedPass( imex::createSetSPIRVAbiAttributePass("opencl")); diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir new file mode 100644 index 000000000..e71149860 --- /dev/null +++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir @@ -0,0 +1,40 @@ +// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s + +// TODO: write CHECK directives + +#map = affine_map<(d0) -> (d0 * 64)> +#map1 = affine_map<(d0) -> (d0 * 16)> + +func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>) { + %cst = arith.constant 0.000000e+00 : f16 + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) { + %x_group_idx = affine.apply #map(%arg5) + %y_group_idx = affine.apply #map(%arg6) + + %x_thread_idx = affine.apply #map1(%arg8) + %y_thread_idx = affine.apply #map1(%arg9) + + %x_global_idx = arith.addi %x_group_idx, %x_thread_idx : index + %y_global_idx = arith.addi %y_group_idx, %y_thread_idx : index + + %a_subview = memref.subview %arg0[%x_global_idx, 0] [16, 1024] [1, 1] : memref<128x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>> + %b_subview = memref.subview %arg1[0, %y_global_idx] [1024, 16] [1, 1] : memref<1024x1024xf16> to memref<1024x16xf16, strided<[1024, 1], offset: ?>> + + %slm_buff = memref.alloc() : memref<64x256xf16, 3> + %slm_subview = memref.subview %slm_buff[%x_thread_idx, %y_thread_idx] [16, 16] [1, 1] : memref<64x256xf16, 3> to memref<16x16xf16, strided<[256, 1], offset: ?>, 3> + + linalg.fill ins(%cst : f16) 
outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>) + linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>) + + %a_add_subview = memref.subview %arg0[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>> + %out_subview = memref.subview %arg2[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>> + + linalg.add ins(%slm_subview, %a_add_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>, memref<16x16xf16, strided<[1024, 1], offset: ?>>) outs(%out_subview : memref<16x16xf16, strided<[1024, 1], offset: ?>>) + gpu.terminator + } + return +} diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir new file mode 100644 index 000000000..76bebe6c4 --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir @@ -0,0 +1,95 @@ +// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s + +module @fragment_name { + func.func @entry(%0: tensor<64x128xf16>, %1: tensor<128x128xf16>, %2: tensor<64x128xf16>, %res: tensor<64x128xf16>) -> tensor<64x128xf16> { + %3 = tensor.empty() : tensor<128x128xf16> + %4 = tensor.empty() : tensor<64x128xf16> + %cst = arith.constant 0.000000e+00 : f16 + %5 = linalg.fill ins(%cst : f16) outs(%4 : tensor<64x128xf16>) -> tensor<64x128xf16> + %6 = linalg.matmul ins(%0, %1 : tensor<64x128xf16>, tensor<128x128xf16>) outs(%5 : tensor<64x128xf16>) -> tensor<64x128xf16> + %7 = tensor.empty() : tensor<64x128xf16> + %8 = linalg.add ins(%6, %2 : tensor<64x128xf16>, tensor<64x128xf16>) outs(%7 : tensor<64x128xf16>) -> tensor<64x128xf16> + %9 = tensor.empty() : tensor<64x128xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %10 = linalg.fill ins(%cst_0 : f16) outs(%9 : tensor<64x128xf16>) -> tensor<64x128xf16> + %11 = linalg.max ins(%8, %10 : tensor<64x128xf16>, tensor<64x128xf16>) outs(%res : tensor<64x128xf16>) -> tensor<64x128xf16> + return %11 : tensor<64x128xf16> + } + + func.func @get_value(%i : index, %j : index, %even_val : f16, %odd_val : f16) -> f16 { + %int0 = arith.index_cast %i : index to i32 + + %c2 = arith.constant 2 : i32 + %remeinder = arith.remui %int0, %c2 : i32 + %c0i = arith.constant 0 : i32 + %is_even = arith.cmpi eq, %remeinder, %c0i : i32 + + %val = scf.if %is_even -> (f16) { + scf.yield %even_val : f16 + } else { + scf.yield %odd_val : f16 + } + return %val : f16 + } + + // generates asymmetric tensor + func.func @generate_t(%even_val : f16, %odd_val : f16) -> tensor<64x128xf16> { + %0 = tensor.generate { + ^bb0(%i : index, %j : index): + %val = func.call @get_value(%i, %j, %even_val, %odd_val) : (index, index, f16, f16) -> f16 + tensor.yield %val : f16 + } : tensor<64x128xf16> + return %0 : tensor<64x128xf16> + } + + func.func @generate_t_wide(%even_val : f16, %odd_val : f16) -> tensor<128x128xf16> { + %0 = tensor.generate { + ^bb0(%i : index, %j : index): + %val = func.call @get_value(%i, %j, %even_val, %odd_val) : (index, index, f16, f16) -> f16 + tensor.yield %val : f16 + } : tensor<128x128xf16> + return %0 : tensor<128x128xf16> + } + + func.func @main() { + %a0 = arith.constant 0.1 : f16 + %b0 = arith.constant 0.2 : f16 + %0 = call @generate_t(%a0, %b0) : (f16, f16) -> tensor<64x128xf16> + + %a1 = arith.constant 0.3 : 
f16 + %b1 = arith.constant 0.4 : f16 + %1 = call @generate_t_wide(%a1, %b1) : (f16, f16) -> tensor<128x128xf16> + + %a2 = arith.constant 0.5 : f16 + %b2 = arith.constant 0.6 : f16 + %2 = call @generate_t(%a2, %b2) : (f16, f16) -> tensor<64x128xf16> + + %3 = arith.constant dense<0.0> : tensor<64x128xf16> + %gpu_res = call @entry(%0, %1, %2, %3) : (tensor<64x128xf16>, tensor<128x128xf16>, tensor<64x128xf16>, tensor<64x128xf16>) -> (tensor<64x128xf16>) + %slice = tensor.extract_slice %gpu_res[0, 0][16, 16][1, 1] : tensor<64x128xf16> to tensor<16x16xf16> + %cast = tensor.cast %slice : tensor<16x16xf16> to tensor<*xf16> + call @printMemrefF16(%cast) : (tensor<*xf16>) -> () + return +} + +func.func private @printMemrefF16(%ptr : tensor<*xf16>) +} + +// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}} +// CHECK-SAME: rank = 2 offset = 0 sizes = [16, 16] strides = [128, 1] data = +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625], +// CHECK-NEXT: [4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047, 4.98047], +// CHECK-NEXT: [9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625, 9.5625] From f63e8f6054a616db87d9f0e56d4a1f2eeeb61e5d Mon Sep 17 00:00:00 2001 From: dchigarev Date: Wed, 13 Nov 2024 16:08:54 +0000 Subject: [PATCH 2/7] Add transform test Signed-off-by: dchigarev --- 
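Note (illustration only, not part of the applied diff): the offset constants
checked below -- the dense<[0, 1, ..., 15, 256, ..., 271]> base vector and the
dense<512> per-load shift -- follow directly from the scatter-tiling scheme in
createScatterDescriptorTiles. The standalone C++ sketch below reproduces them
for the 16x16 sub-tile of the 64x256 f16 SLM buffer used in this test; the
variable names are illustrative, not taken from the pass.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t rowStride = 256;   // row stride of the 64x256 SLM buffer
  const int64_t tileRows = 16;     // per-thread SLM sub-tile is 16x16
  const int64_t tileCols = 16;
  const int64_t maxLoadSize = 32;  // elements covered by one scatter descriptor
  const int64_t rowsPerLoad = maxLoadSize / tileCols;  // 2 rows per descriptor

  // Base offsets of the first descriptor, relative to the sub-tile origin
  // (the 'colTileShift' vector in the CHECK lines).
  std::vector<int64_t> base;
  for (int64_t i = 0; i < rowsPerLoad; ++i)
    for (int64_t j = 0; j < tileCols; ++j)
      base.push_back(i * rowStride + j);

  // Each subsequent descriptor is shifted by two full rows
  // (the 'loadOffset' vector, dense<512>).
  const int64_t skipPerLoad = rowStride * rowsPerLoad;
  const int64_t numDescriptors = tileRows / rowsPerLoad;  // 8 per 16x16 tile

  for (int64_t d = 0; d < numDescriptors; ++d) {
    std::cout << "descriptor " << d << ":";
    for (int64_t off : base)
      std::cout << ' ' << (off + d * skipPerLoad);
    std::cout << '\n';
  }
  return 0;
}
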
lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 2 +- .../Transforms/GPU/linalg-to-xegpu-slm.mlir | 75 ++++++++++++++++--- .../XeGPU/f16_matmul_64x128_slm.mlir | 1 + 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index 584de6489..38ab881a8 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -888,7 +888,7 @@ static SmallVector createSLMDescTiles(PatternRewriter &rewriter, assert(loadShape.size() <= 2 && "Require at most 2D tile size for eltwise lowering"); - auto srcType = src.getType().cast(); + auto srcType = cast(src.getType()); assert(srcType.getRank() == 2 && "Expected a 2D memref"); auto elemByteWidth = srcType.getElementType().getIntOrFloatBitWidth() / 8; diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir index e71149860..8b7e4646a 100644 --- a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir +++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir @@ -1,22 +1,26 @@ -// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s - -// TODO: write CHECK directives +// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s #map = affine_map<(d0) -> (d0 * 64)> #map1 = affine_map<(d0) -> (d0 * 16)> func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>) { + // CHECK: %[[loadAccumMatmul:.+]] = arith.constant dense<0.000000e+00> : vector<4x32xf16> + // CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<32xf16> + // CHECK: %[[colTileShift:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271]> : vector<32xindex> + // CHECK: %[[loadOffset:.+]] = arith.constant dense<512> : vector<32xindex> %cst = arith.constant 0.000000e+00 : f16 %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %c16 = arith.constant 16 : index - gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) { - %x_group_idx = affine.apply #map(%arg5) - %y_group_idx = affine.apply #map(%arg6) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) { + %x_group_idx = affine.apply #map(%arg3) + %y_group_idx = affine.apply #map(%arg4) - %x_thread_idx = affine.apply #map1(%arg8) - %y_thread_idx = affine.apply #map1(%arg9) + // CHECK: %[[X_THREAD_IDX:.+]] = affine.apply #map1(%arg6) + // CHECK: %[[Y_THREAD_IDX:.+]] = affine.apply #map1(%arg7) + %x_thread_idx = affine.apply #map1(%arg6) + %y_thread_idx = affine.apply #map1(%arg7) %x_global_idx = arith.addi %x_group_idx, %x_thread_idx : index %y_global_idx = arith.addi %y_group_idx, %y_thread_idx : index @@ -24,16 +28,63 @@ func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg %a_subview = memref.subview %arg0[%x_global_idx, 0] [16, 1024] [1, 1] : memref<128x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>> %b_subview = memref.subview %arg1[0, %y_global_idx] [1024, 16] [1, 1] : memref<1024x1024xf16> to memref<1024x16xf16, strided<[1024, 1], offset: ?>> + // CHECK: %[[SLM_BUFF:.+]] = 
memref.alloc() : memref<64x256xf16, 3> %slm_buff = memref.alloc() : memref<64x256xf16, 3> + // CHECK-NOT: .* = memref.subview %[[SLM_BUFF]] .* + // CHECK: %[[SLM_X_OFF:.+]] = arith.muli %[[X_THREAD_IDX]], %c256 : index + // CHECK: %[[SLM_THREAD_OFF:.+]] = arith.addi %[[SLM_X_OFF]], %[[Y_THREAD_IDX]] : index + // CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c16384], strides: [%c1] : memref<64x256xf16, 3> to memref<16384xf16, 3> %slm_subview = memref.subview %slm_buff[%x_thread_idx, %y_thread_idx] [16, 16] [1, 1] : memref<64x256xf16, 3> to memref<16x16xf16, strided<[256, 1], offset: ?>, 3> + // CHECK: %[[SLM_THREAD_OFF_V:.+]] = vector.splat %[[SLM_THREAD_OFF]] : vector<32xindex> + // CHECK: %[[DESC_OFFSET0:.+]] = arith.addi %[[SLM_THREAD_OFF_V]], %[[colTileShift]] : vector<32xindex> + // CHECK: %[[ROOT_DESC:.+]] = xegpu.create_tdesc %[[FLAT_SLM]], %[[DESC_OFFSET0]] : memref<16384xf16, 3>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> + // CHECK: %[[FILL_DESC1:.+]] = xegpu.update_offset %[[ROOT_DESC]], %[[loadOffset]] + // CHECK: %[[FILL_DESC2:.+]] = xegpu.update_offset %[[FILL_DESC1]], %[[loadOffset]] + // CHECK-COUNT-5: xegpu.update_offset + + // CHECK: xegpu.store %[[ZERO]], %[[ROOT_DESC]] + // CHECK: xegpu.store %[[ZERO]], %[[FILL_DESC1]] + // CHECK-COUNT-6: xegpu.store linalg.fill ins(%cst : f16) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>) - linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>) - %a_add_subview = memref.subview %arg0[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>> - %out_subview = memref.subview %arg2[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>> + // CHECK: %[[MATMUL_DESC1:.+]] = xegpu.update_offset %[[ROOT_DESC]], %[[loadOffset]] + // CHECK: %[[MATMUL_DESC2:.+]] = xegpu.update_offset %[[MATMUL_DESC1]], %[[loadOffset]] + // CHECK-COUNT-5: xegpu.update_offset + + // CHECK: %[[MATMUL_LOAD0:.+]] = xegpu.load %[[ROOT_DESC]] + // CHECK-NEXT: %[[loadAccumMatmul1:.+]] = vector.insert %[[MATMUL_LOAD0]], %[[loadAccumMatmul]] [0] + // CHECK-NEXT: %[[MATMUL_LOAD1:.+]] = xegpu.load %[[MATMUL_DESC1]] + // CHECK-NEXT: %[[loadAccumMatmul2:.+]] = vector.insert %[[MATMUL_LOAD1]], %[[loadAccumMatmul1]] [1] + // CHECK-COUNT-2: xegpu.load + + // CHECK: vector.shape_cast + // CHECK-SAME: vector<4x32xf16> to vector<128xf16> + // CHECK: vector.shape_cast + // CHECK-SAME: vector<128xf16> to vector<8x16xf16> - linalg.add ins(%slm_subview, %a_add_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>, memref<16x16xf16, strided<[1024, 1], offset: ?>>) outs(%out_subview : memref<16x16xf16, strided<[1024, 1], offset: ?>>) + // CHECK-COUNT-4: xegpu.load + // CHECK: vector.shape_cast + // CHECK-SAME: vector<4x32xf16> to vector<128xf16> + // CHECK: vector.shape_cast + // CHECK-SAME: vector<128xf16> to vector<8x16xf16> + + // STORE: + // %[[FLAT_MATMUL_RES0:.+]] = vector.shape_cast %[[MATMUL_RES0:.+]] : vector<8x16xf16> to vector<128xf16> + // %[[STORE_TILE0:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES0]] {offsets = [0], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16> + // xegpu.store %[[STORE_TILE0]], %[[ROOT_DESC]] + // %[[STORE_TILE1:.+]] = 
vector.extract_strided_slice %[[FLAT_MATMUL_RES0]] {offsets = [32], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16> + // xegpu.store %[[STORE_TILE0]], %[[MATMUL_DESC1]] + // CHECK-COUNT-2: xegpu.store + + // %[[FLAT_MATMUL_RES1:.+]] = vector.shape_cast %[[MATMUL_RES1:.+]] : vector<8x16xf16> to vector<128xf16> + // %[[STORE_TILE1_0:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES1]] {offsets = [0], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16> + // xegpu.store %[[STORE_TILE1_0]] + // %[[STORE_TILE1_1:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES1]] {offsets = [32], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16> + // xegpu.store %[[STORE_TILE1_1]] + // CHECK-COUNT-2: xegpu.store + + linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>) gpu.terminator } return diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir index 76bebe6c4..4cc15c911 100644 --- a/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir @@ -1,6 +1,7 @@ // RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s module @fragment_name { + // This kernel requires using SLM func.func @entry(%0: tensor<64x128xf16>, %1: tensor<128x128xf16>, %2: tensor<64x128xf16>, %res: tensor<64x128xf16>) -> tensor<64x128xf16> { %3 = tensor.empty() : tensor<128x128xf16> %4 = tensor.empty() : tensor<64x128xf16> From e48f7437c3695938971feaf6749d0de7e6b1b22e Mon Sep 17 00:00:00 2001 From: dchigarev Date: Wed, 13 Nov 2024 17:22:04 +0000 Subject: [PATCH 3/7] Move util functions to a separate file Signed-off-by: dchigarev --- include/gc/Transforms/Utils/ValueUtils.h | 17 +++++ lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 91 ++++-------------------- lib/gc/Transforms/Utils/ValueUtils.cpp | 43 +++++++++++ 3 files changed, 75 insertions(+), 76 deletions(-) diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index ef8cf36ce..31d22807d 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -33,6 +33,23 @@ FailureOr> getStaticStrides(Value val); // is not a memref. 
std::pair getPtrAndOffset(OpBuilder &builder, Value operand); +template +Value createTypedVector(PatternRewriter &rewriter, Location loc, + ArrayRef values, Type elementType) { + mlir::VectorType vectorType = + mlir::VectorType::get({static_cast(values.size())}, elementType); + mlir::DenseElementsAttr denseAttr = + mlir::DenseElementsAttr::get(vectorType, values); + auto vector = + rewriter.create(loc, vectorType, denseAttr) + .getResult(); + return vector; +} + +Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref); + +bool hasSharedMemSpace(mlir::Value memref); + } // namespace utils } // namespace mlir diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index 38ab881a8..4920151c2 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -49,78 +49,15 @@ namespace gc { namespace { -// TODO: move these to utils -template -static Value createTypedVector(PatternRewriter &rewriter, Location loc, - ArrayRef values, Type elementType) { - mlir::VectorType vectorType = - mlir::VectorType::get({static_cast(values.size())}, elementType); - mlir::DenseElementsAttr denseAttr = - mlir::DenseElementsAttr::get(vectorType, values); - auto vector = - rewriter.create(loc, vectorType, denseAttr) - .getResult(); - return vector; -} - -static Value createIndexVector(PatternRewriter &rewriter, Location loc, - ArrayRef values) { - return createTypedVector(rewriter, loc, values, rewriter.getIndexType()); -} - -static Value createIndexConstant(PatternRewriter &rewriter, Location loc, - int64_t value) { - return rewriter.create(loc, value); -} - -static Value flattenMemref(PatternRewriter &rewriter, Location loc, - Value srcMemref) { - auto srcType = cast(srcMemref.getType()); - - assert(srcType && "Expected a memref type"); - assert(srcType.getRank() == 2 && "Expected a 2D memref"); - - int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1]; - - Value offset = rewriter.create(loc, 0); - Value size = rewriter.create(loc, flatSize); - Value stride = rewriter.create(loc, 1); - - // Use memref.reinterpret_cast to flatten the memref - auto flatMemRefType = MemRefType::get({flatSize}, srcType.getElementType(), - nullptr, srcType.getMemorySpace()); - auto flatMemref = - rewriter - .create(loc, flatMemRefType, srcMemref, - offset, size, stride) - .getResult(); - return flatMemref; -} - -static bool hasSharedMemSpace(mlir::Value memref) { - auto type = mlir::dyn_cast(memref.getType()); - if (!type) - return false; - - auto memSpace = type.getMemorySpace(); - if (!memSpace) - return false; - - if (auto gpuAttr = mlir::dyn_cast(memSpace)) - return gpuAttr.getValue() == mlir::gpu::AddressSpace::Private; - - if (auto intAttr = mlir::dyn_cast(memSpace)) - return intAttr.getValue() == - static_cast(mlir::gpu::AddressSpace::Private); - - return false; -} - static Value createFullMask(PatternRewriter &rewriter, Location loc, int64_t size) { - auto maskVal = createIndexConstant(rewriter, loc, 32); + auto maskVal = rewriter.create(loc, 32); mlir::VectorType maskVectorType = mlir::VectorType::get({size}, rewriter.getI1Type()); + // HACK: creating mask vector through this strange op instead of + // simple 'arith.constant dense' to avoid the mask being + // moved out of the GPU kernel (it causes strange behaviour + // when a bit-mask is passed as a kernel parameter). 
auto res = rewriter.create( loc, maskVectorType, SmallVector({maskVal})); return res.getResult(); @@ -826,8 +763,9 @@ static SmallVector createScatterDescriptorTiles( } int64_t skipPerLoad = memrefStrides[0] * rowsPerLoad; - auto offsetPerLoad = - createIndexVector(rewriter, loc, SmallVector(32, skipPerLoad)); + auto offsetPerLoad = utils::createTypedVector( + rewriter, loc, SmallVector(32, skipPerLoad), + rewriter.getIndexType()); auto offsetVecType = VectorType::get({maxLoadSize}, rewriter.getIndexType()); auto descType = getTensorDescType( @@ -843,7 +781,8 @@ static SmallVector createScatterDescriptorTiles( SmallVector tiles; for (int i = 0; i < numColTiles; i++) { - auto offsetsShift = createIndexVector(rewriter, loc, offsetShiftValues[i]); + auto offsetsShift = utils::createTypedVector( + rewriter, loc, offsetShiftValues[i], rewriter.getIndexType()); auto offsets0 = rewriter.create(loc, blockOffsetV, offsetsShift); @@ -924,7 +863,7 @@ static SmallVector createSLMDescTiles(PatternRewriter &rewriter, } // Scatter descriptors only work with 1D memrefs - src = flattenMemref(rewriter, loc, src); + src = utils::flattenMemref(rewriter, loc, src); return createScatterDescriptorTiles( rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape, @@ -938,7 +877,7 @@ static SmallVector createDescriptorTiles( std::optional> loadOffsets = std::nullopt, int arrayLength = 1, bool transpose = false) { - if (hasSharedMemSpace(src)) { + if (utils::hasSharedMemSpace(src)) { assert(!transpose && "Transpose is not supported for shared memory"); assert(arrayLength == 1 && "Array descriptors are not supported for shared memory"); @@ -1092,8 +1031,8 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc, // Accumulator vector for the current tile (its number of elements equals to // tileShape) HACK: we first create a flat vector of zeros and then cast it // to the 2D shape. Otherwise 'imex::ConvertGPUXToSPIRVPass' fails. - auto accumVector = - createTypedVector(rewriter, loc, accumValues, elementType); + auto accumVector = utils::createTypedVector( + rewriter, loc, accumValues, elementType); accumVector = rewriter.create(loc, accumVectorType, accumVector); @@ -1961,7 +1900,7 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp, } // Extract SIMD sized sub-tiles - int64_t maxSizeSIMD = hasSharedMemSpace(output) ? 32 : 256; + int64_t maxSizeSIMD = utils::hasSharedMemSpace(output) ? 
32 : 256; int64_t subTileCols = std::min(outputShape[1], maxSizeSIMD); int64_t subTileRows = std::min(outputShape[0], std::max(maxSizeSIMD / subTileCols, 1L)); diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index a3ce3619d..8e3de4213 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -150,5 +151,47 @@ std::pair getPtrAndOffset(OpBuilder &builder, Value operand) { return std::make_pair(alignedPointer, offset); } +Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) { + auto srcType = cast(srcMemref.getType()); + + assert(srcType && "Expected a memref type"); + assert(srcType.getRank() == 2 && "Expected a 2D memref"); + + int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1]; + + Value offset = rewriter.create(loc, 0); + Value size = rewriter.create(loc, flatSize); + Value stride = rewriter.create(loc, 1); + + // Use memref.reinterpret_cast to flatten the memref + auto flatMemRefType = MemRefType::get({flatSize}, srcType.getElementType(), + nullptr, srcType.getMemorySpace()); + auto flatMemref = + rewriter + .create(loc, flatMemRefType, srcMemref, + offset, size, stride) + .getResult(); + return flatMemref; +} + +bool hasSharedMemSpace(mlir::Value memref) { + auto type = mlir::dyn_cast(memref.getType()); + if (!type) + return false; + + auto memSpace = type.getMemorySpace(); + if (!memSpace) + return false; + + if (auto gpuAttr = mlir::dyn_cast(memSpace)) + return gpuAttr.getValue() == mlir::gpu::AddressSpace::Private; + + if (auto intAttr = mlir::dyn_cast(memSpace)) + return intAttr.getValue() == + static_cast(mlir::gpu::AddressSpace::Private); + + return false; +} + } // namespace utils } // namespace mlir From 884f4baf0c93efaac2f8e0869d3fbafe3a9eca3a Mon Sep 17 00:00:00 2001 From: dchigarev Date: Thu, 14 Nov 2024 11:50:11 +0000 Subject: [PATCH 4/7] clang-format Signed-off-by: dchigarev --- CMakeLists.txt | 3 --- include/gc/Transforms/Utils/ValueUtils.h | 3 +++ include/gc/Utils/Log.h | 1 - lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 5 ++++- lib/gc/Transforms/GPU/Pipeline.cpp | 1 + 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 24ffcd614..66af39632 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,7 +140,4 @@ install(FILES ${PROJECT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake DESTINATION "lib/cmake/${PROJECT_NAME}" ) - -message("C++ Standard: ${CMAKE_CXX_STANDARD}") - ################################################################################ diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index 31d22807d..409f563b1 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -33,6 +33,7 @@ FailureOr> getStaticStrides(Value val); // is not a memref. std::pair getPtrAndOffset(OpBuilder &builder, Value operand); +// Create a 'mlir::vector' constant from a list of values. template Value createTypedVector(PatternRewriter &rewriter, Location loc, ArrayRef values, Type elementType) { @@ -46,8 +47,10 @@ Value createTypedVector(PatternRewriter &rewriter, Location loc, return vector; } +// Flatten a 2D memref to a 1D memref. 
 Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref);
 
+// Return true if the memref has shared memory space.
 bool hasSharedMemSpace(mlir::Value memref);
 
 } // namespace utils
diff --git a/include/gc/Utils/Log.h b/include/gc/Utils/Log.h
index 6bb765a34..8755cf933 100644
--- a/include/gc/Utils/Log.h
+++ b/include/gc/Utils/Log.h
@@ -68,7 +68,6 @@ static void debug(const char *fileName, int lineNum, Args... args) {
 #define gcLogD(...) mlir::gc::log::debug(__FILE__, __LINE__, __VA_ARGS__)
 #define gcLogE(...) \
   mlir::gc::log::log(__FILE__, __LINE__, std::cerr, "ERROR", __VA_ARGS__)
-#define gcRunD(...) if (mlir::gc::log::isDebugEnabled(__FILE__)) {__VA_ARGS__;}
 #endif
 
 } // namespace mlir::gc::log
diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index 4920151c2..3a4c0aa10 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -745,12 +745,14 @@ static SmallVector createScatterDescriptorTiles(
          "Descriptor tile must be divisible by max load size");
 
   int64_t numLoadsPerTile = tileSize2D[0] * tileSize2D[1] / maxLoadSize;
-  // How much rows of a tile fit into a single descriptor load
+  // This indicates how many rows of a single tile (defined by tileSize2D) are
+  // loaded per single load operation (single load loads exactly 32 elements).
   int64_t rowsPerLoad = maxLoadSize / tileSize2D[1];
   int64_t numColTiles = loadShape2D[1] / tileSize2D[1];
 
   auto memrefType = dyn_cast(flatMemref.getType());
 
+  // compute load offsets for each colTile
   SmallVector> offsetShiftValues;
   for (int colTile = 0; colTile < numColTiles; colTile++) {
     offsetShiftValues.push_back(SmallVector());
@@ -762,6 +764,7 @@ static SmallVector createScatterDescriptorTiles(
     }
   }
 
+  // This indicates an offset between two loads
   int64_t skipPerLoad = memrefStrides[0] * rowsPerLoad;
   auto offsetPerLoad = utils::createTypedVector(
       rewriter, loc, SmallVector(32, skipPerLoad),
diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp
index 3b4bf5525..b3b1036c3 100644
--- a/lib/gc/Transforms/GPU/Pipeline.cpp
+++ b/lib/gc/Transforms/GPU/Pipeline.cpp
@@ -69,6 +69,7 @@ void populateGPUPipeline(OpPassManager &pm,
   pm.addNestedPass(createAllocsToSLM());
   pm.addNestedPass(createLinalgToXeGPU(
       {/*kTile=*/16, /*stages=*/1, /*dpasTiles=*/{8, 16, 16}}));
+  pm.addPass(createCSEPass());
 
   pm.addNestedPass(createConvertLinalgToLoopsPass());
   pm.addPass(xegpu::createXeGPUFoldAliasOps());

From e2b6eb08be7dd9e76dd0e2febb4ecf09e93e5186 Mon Sep 17 00:00:00 2001
From: dchigarev
Date: Thu, 14 Nov 2024 11:56:21 +0000
Subject: [PATCH 5/7] remove unused variables

Signed-off-by: dchigarev
---
 lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index 3a4c0aa10..e480e5188 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -12,7 +12,6 @@
 
 #include "gc/Transforms/Utils/StructuredOpMatcher.h"
 #include "gc/Transforms/Utils/ValueUtils.h"
-#include "gc/Utils/Log.h"
 #include "mlir/Conversion/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -832,7 +831,6 @@ static SmallVector createSLMDescTiles(PatternRewriter &rewriter,
 
   auto srcType = cast(src.getType());
   assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  auto elemByteWidth = srcType.getElementType().getIntOrFloatBitWidth() / 8;
   SmallVector memrefStrides;
   Value blockOffset;
 
From 55b0e2fa399bf9b947cdcbbcbfe2e1f2192928a0 Mon Sep 17 00:00:00 2001
From: dchigarev
Date: Thu, 14 Nov 2024 14:30:02 +0000
Subject: [PATCH 6/7] use constexpr for SLM tile size

Signed-off-by: dchigarev
---
 lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 61 ++++++++++++-------------
 1 file changed, 30 insertions(+), 31 deletions(-)

diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index e480e5188..02ef8a7e5 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -62,6 +62,9 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
   return res.getResult();
 }
 
+// Max number of elements to load/store from SLM
+constexpr int64_t maxSLMTileSize = 32;
+
 // Represents VNNI configuration for an operand.
 struct VnniConfig {
   int vnniFactor;
@@ -732,21 +735,19 @@ static SmallVector createScatterDescriptorTiles(
     PatternRewriter &rewriter, Location loc, Value flatMemref,
     ArrayRef loadShape2D, ArrayRef tileSize2D,
     ArrayRef memrefStrides, Value blockOffset) {
-  int64_t maxLoadSize = 32;
-
   assert(memrefStrides.size() == 2 && "Strides must be 2D");
   assert(memrefStrides[1] == 1 && "Only row-major strides are supported");
   assert(loadShape2D.size() == 2 && "Load shape must be 2D");
-  assert(loadShape2D[0] * loadShape2D[1] % maxLoadSize == 0 &&
+  assert(loadShape2D[0] * loadShape2D[1] % maxSLMTileSize == 0 &&
          "Load shape must be divisible by max load size");
   assert(tileSize2D.size() == 2 && "Descriptor tile must be 2D");
-  assert(maxLoadSize % tileSize2D[1] == 0 &&
+  assert(maxSLMTileSize % tileSize2D[1] == 0 &&
          "Descriptor tile must be divisible by max load size");
 
-  int64_t numLoadsPerTile = tileSize2D[0] * tileSize2D[1] / maxLoadSize;
+  int64_t numLoadsPerTile = tileSize2D[0] * tileSize2D[1] / maxSLMTileSize;
   // This indicates how many rows of a single tile (defined by tileSize2D) are
   // loaded per single load operation (single load loads exactly 32 elements).
-  int64_t rowsPerLoad = maxLoadSize / tileSize2D[1];
+  int64_t rowsPerLoad = maxSLMTileSize / tileSize2D[1];
   int64_t numColTiles = loadShape2D[1] / tileSize2D[1];
 
   auto memrefType = dyn_cast(flatMemref.getType());
@@ -757,7 +758,7 @@ static SmallVector createScatterDescriptorTiles(
     offsetShiftValues.push_back(SmallVector());
     for (int i = 0; i < rowsPerLoad; i++) {
       int64_t offset = i * memrefStrides[0];
-      for (int j = 0; j < maxLoadSize / rowsPerLoad; j++)
+      for (int j = 0; j < maxSLMTileSize / rowsPerLoad; j++)
         offsetShiftValues[colTile].push_back(offset + j + colTile * tileSize2D[1]);
     }
   }
@@ -769,9 +770,10 @@ static SmallVector createScatterDescriptorTiles(
       rewriter, loc, SmallVector(32, skipPerLoad),
       rewriter.getIndexType());
 
-  auto offsetVecType = VectorType::get({maxLoadSize}, rewriter.getIndexType());
+  auto offsetVecType =
+      VectorType::get({maxSLMTileSize}, rewriter.getIndexType());
   auto descType = getTensorDescType(
-      {maxLoadSize}, memrefType.getElementType(),
+      {maxSLMTileSize}, memrefType.getElementType(),
       xegpu::ScatterTensorDescAttr::get(
           rewriter.getContext(), xegpu::MemorySpace::SLM,
           /*chunkSize=*/1));
@@ -793,8 +795,9 @@ static SmallVector createScatterDescriptorTiles(
             .create(loc, descType, flatMemref, offsets0)
             .getResult();
     tiles.push_back(desc);
-    for (int j = maxLoadSize; j < loadShape2D[0] * loadShape2D[1] / numColTiles;
-         j += maxLoadSize) {
+    for (int j = maxSLMTileSize;
+         j < loadShape2D[0] * loadShape2D[1] / numColTiles;
+         j += maxSLMTileSize) {
      auto newTile = rewriter
                         .create(
                             loc, descType, tiles.back(), offsetPerLoad)
@@ -994,39 +997,37 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
                      std::optional vnniConf = std::nullopt,
                      DenseI64ArrayAttr transpose = nullptr,
                      IntegerAttr transpose_bit = nullptr) {
-  int64_t elementsPerLoad = 32;
-
   // Assume all tiles have the same shape.
   auto tileType = cast(loadTiles[0].getType());
   assert(llvm::all_of(loadTiles,
                       [&](Value tile) { return tile.getType() == tileType; }) &&
          "All load tiles must have the same type.");
   assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D");
-  assert(tileType.getShape()[0] == elementsPerLoad &&
+  assert(tileType.getShape()[0] == maxSLMTileSize &&
          "Scatter tiles must have 32 elements");
   assert(!vnniConf && "VNNI not supported for scatter loads");
   assert(!transpose && "Transpose is not supported for scatter loads");
   assert(!transpose_bit && "Transpose is not supported for scatter loads");
 
   int64_t totalLoadElems = tileType.getShape()[0] * loadTiles.size();
-  assert(totalLoadElems % elementsPerLoad == 0 &&
+  assert(totalLoadElems % maxSLMTileSize == 0 &&
          "Total load size must be multiple of 32");
-  assert(tileShape[0] * tileShape[1] % elementsPerLoad == 0 &&
+  assert(tileShape[0] * tileShape[1] % maxSLMTileSize == 0 &&
          "Tile shape must be multiple of 32");
 
-  int64_t loadsPerTile = tileShape[0] * tileShape[1] / elementsPerLoad;
-  int64_t totalNumLoads = totalLoadElems / elementsPerLoad;
-  auto mask = createFullMask(rewriter, loc, elementsPerLoad);
+  int64_t loadsPerTile = tileShape[0] * tileShape[1] / maxSLMTileSize;
+  int64_t totalNumLoads = totalLoadElems / maxSLMTileSize;
+  auto mask = createFullMask(rewriter, loc, maxSLMTileSize);
 
   SmallVector result;
   auto elementType = tileType.getElementType();
 
   SmallVector accumValues(
-      loadsPerTile * elementsPerLoad,
+      loadsPerTile * maxSLMTileSize,
      dyn_cast(rewriter.getZeroAttr(elementType)));
 
   VectorType accumVectorType =
-      VectorType::get({loadsPerTile, elementsPerLoad}, elementType);
-  VectorType loadVectorType = VectorType::get({elementsPerLoad}, elementType);
+      VectorType::get({loadsPerTile, maxSLMTileSize}, elementType);
+  VectorType loadVectorType = VectorType::get({maxSLMTileSize}, elementType);
 
   for (int64_t tileIdx = 0; tileIdx < totalNumLoads; tileIdx += loadsPerTile) {
     // Accumulator vector for the current tile (its number of elements equals to
@@ -1049,7 +1050,7 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
           loc, loadOp.getResult(), accumVector, SmallVector{loadIdx});
     }
 
-    if (tileShape[1] == elementsPerLoad) {
+    if (tileShape[1] == maxSLMTileSize) {
       // No need to reshape the accumulator vector.
       result.push_back(accumVector);
       continue;
@@ -1115,24 +1116,22 @@ static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc,
                                   SmallVector &results,
                                   ValueRange storeTiles,
                                   xegpu::CachePolicyAttr hint) {
-  int64_t elementsPerStore = 32;
-
   auto tileType = cast(storeTiles[0].getType());
   assert(llvm::all_of(storeTiles,
                       [&](Value tile) { return tile.getType() == tileType; }) &&
          "All load tiles must have the same type.");
   assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D");
-  assert(tileType.getShape()[0] == elementsPerStore &&
+  assert(tileType.getShape()[0] == maxSLMTileSize &&
          "Scatter tiles must have 32 elements");
 
-  auto mask = createFullMask(rewriter, loc, elementsPerStore);
+  auto mask = createFullMask(rewriter, loc, maxSLMTileSize);
 
   int64_t descIdx = 0;
   for (auto vec : results) {
     auto vecType = dyn_cast(vec.getType());
     auto vecShape = vecType.getShape();
     assert(vecShape.size() == 2 && "Expected 2D vector");
-    assert(vecShape[0] * vecShape[1] % elementsPerStore == 0 &&
+    assert(vecShape[0] * vecShape[1] % maxSLMTileSize == 0 &&
           "Vector shape must be divisible by load size");
 
     // Flatten the vector to 1D
@@ -1142,10 +1141,10 @@ static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc,
                                              vec);
     // Extract slices of 32 size from 'flatVec' and store them
     for (int64_t loadChunkIdx = 0; loadChunkIdx < vecShape[0] * vecShape[1];
-         loadChunkIdx += elementsPerStore) {
+         loadChunkIdx += maxSLMTileSize) {
       auto toStore = rewriter.create(
          loc, flatVec, /*offsets=*/SmallVector({loadChunkIdx}),
-          /*sizes=*/SmallVector({elementsPerStore}),
+          /*sizes=*/SmallVector({maxSLMTileSize}),
          /*strides=*/SmallVector({1}));
      rewriter.create(loc, toStore, storeTiles[descIdx],
                      /*mask=*/mask,
@@ -1901,7 +1900,7 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
   }
 
   // Extract SIMD sized sub-tiles
-  int64_t maxSizeSIMD = utils::hasSharedMemSpace(output) ? 32 : 256;
+  int64_t maxSizeSIMD = utils::hasSharedMemSpace(output) ? maxSLMTileSize : 256;
   int64_t subTileCols = std::min(outputShape[1], maxSizeSIMD);
   int64_t subTileRows =
       std::min(outputShape[0], std::max(maxSizeSIMD / subTileCols, 1L));

From 711e4770a465eacd54cc79e468b93277667787a2 Mon Sep 17 00:00:00 2001
From: dchigarev
Date: Thu, 14 Nov 2024 15:23:25 +0000
Subject: [PATCH 7/7] rename sgMap to descAttr

Signed-off-by: dchigarev
---
 lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index 02ef8a7e5..86b45466b 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -102,14 +102,14 @@ struct TilesArray {
 
 static xegpu::TensorDescType
 getTensorDescType(llvm::ArrayRef shape, mlir::Type elementType,
-                  std::optional sgMap = std::nullopt) {
-  if (!sgMap) {
+                  std::optional descAttr = std::nullopt) {
+  if (!descAttr) {
     // Assuming default tensor descriptor type (blocked & in global memory).
     return xegpu::TensorDescType::get(shape, elementType,
                                       /*array_length=*/1, /*boundary_check=*/true);
   }
 
-  auto descriptor = sgMap.value();
+  auto descriptor = descAttr.value();
   if (auto scatterMap = dyn_cast(descriptor)) {
     auto memSpace = scatterMap.getMemorySpace().getValue();
     int64_t chunkSize = scatterMap.getChunkSize().getInt();