diff --git a/include/gc/Dialect/Linalgx/LinalgxOps.td b/include/gc/Dialect/Linalgx/LinalgxOps.td index 4491967c3..9bd28c4c2 100644 --- a/include/gc/Dialect/Linalgx/LinalgxOps.td +++ b/include/gc/Dialect/Linalgx/LinalgxOps.td @@ -11,8 +11,33 @@ include "LinalgxDialect.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" + // Base class for Linalg dialect ops that do not correspond to library calls. class Linalgx_Op traits = []> : Op; +def Linalgx_ScaledDotProductAttentionOp + : Linalgx_Op<"scaled_dot_product_attention", + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "Attention structure."; + let description = [{ + Q, K, V, attention_mask. + Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V. + }]; + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs); + let results = (outs Variadic:$results); + + let hasVerifier = 1; + let assemblyFormat = [{ + attr-dict + `ins` `(` $inputs `:` type($inputs) `)` + `outs` `(` $outputs `:` type($outputs) `)` + (`->` type($results)^)? + }]; +} #endif // LINALGX_OPS \ No newline at end of file diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 79a62f028..77e2263a5 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -32,6 +32,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { ]; } +def FlashAttentionConversion + : Pass<"flash-attention-conversion", "func::FuncOp"> { + let summary = "Flash Attention Conversion"; + let description = + [{The pass converts MHA to flash attention implementation.}]; + let dependentDialects = [ + "func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect", + "tensor::TensorDialect" + ]; +} + #ifdef GC_USE_GPU def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { let summary = "Convert linalg dialect to XeGPU dialect."; diff --git a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp b/lib/gc/Dialect/Linalgx/LinalgxOps.cpp index 04eae3657..b22cda771 100644 --- a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp +++ b/lib/gc/Dialect/Linalgx/LinalgxOps.cpp @@ -9,6 +9,7 @@ #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Dialect/Linalgx/LinalgxDialect.h" #include "mlir/IR/OpImplementation.h" +#include //===----------------------------------------------------------------------===// // Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -608,6 +609,80 @@ void MultiBatchMatmulOp::getEffects( getGenericEffectsImpl(effects, cast(getOperation())); } +//===----------------------------------------------------------------------===// +// ScaledDotProductAttentionOp +//===----------------------------------------------------------------------===// + +LogicalResult ScaledDotProductAttentionOp::verify() { return success(); } + +/// This method converts ScaledDotProductAttention into the following +/// sequence of operations: +/// output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2] +FailureOr> +ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(*this); + Location loc = getLoc(); + Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2], + mask = getInputs()[3]; + auto dtype = cast(query.getType()).getElementType(); + auto shape = cast(query.getType()).getShape(); + float rsqrt_head = 1 / sqrt(shape[3]); + + SmallVector permutation{0, 1, 3, 2}; + SmallVector transposeShape{shape[0], shape[1], shape[3], shape[2]}; + auto transposeOut = 
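+  // Shape walk-through of the decomposition below, using the
+  // [batch, numHeads, seqLen, headDim] layout of the inputs:
+  //   Q                 : [B, H, S, D]
+  //   transpose(K)      : [B, H, D, S]
+  //   S = Q @ K^T       : [B, H, S, S], scaled by 1 / sqrt(D)
+  //   S + attention_mask: [B, H, S, S]
+  //   P = softmax(S)    : [B, H, S, S], reduced over the last dimension
+  //   P @ V             : [B, H, S, D]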
b.create(loc, transposeShape, dtype); + auto transpose = b.create( + /*location=*/loc, + /*inputs=*/key, + /*outputs=*/transposeOut, + /*permutation=*/permutation); + + SmallVector matmulQKShape{shape[0], shape[1], shape[2], shape[2]}; + auto matmulQKOut = b.create(loc, matmulQKShape, dtype); + auto matmulQK = b.create( + /*location=*/loc, matmulQKOut.getResult().getType(), + /*inputs=*/ValueRange{query, transpose->getResult(0)}, + /*outputs=*/ValueRange{matmulQKOut.getResult()}); + + auto mulOut = b.create(loc, matmulQKShape, dtype); + // Broadcast the initial value to the output tensor before convolving. + SmallVector indexingMaps; + indexingMaps.push_back(b.getMultiDimIdentityMap(4)); + indexingMaps.push_back(b.getMultiDimIdentityMap(4)); + auto mul = b.create( + /*location=*/loc, matmulQKOut.getResult().getType(), + /*inputs=*/ValueRange{matmulQK->getResult(0)}, + /*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps, + SmallVector(4, utils::IteratorType::parallel), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + Value constant = b.create( + loc, nestedBuilder.getFloatAttr(dtype, rsqrt_head)); + Value added = + nestedBuilder.create(loc, args[0], constant); + nestedBuilder.create(nestedLoc, added); + }); + + auto addOut = b.create(loc, matmulQKShape, dtype); + auto add = b.create( + /*location=*/loc, addOut.getResult().getType(), + /*inputs=*/ValueRange{mul->getResult(0), mask}, + /*outputs=*/ValueRange{addOut.getResult()}); + + auto softmaxOut = b.create(loc, matmulQKShape, dtype); + auto softmax = b.create( + /*location=*/loc, softmaxOut.getResult().getType(), + /*inputs=*/add->getResult(0), + /*outputs=*/softmaxOut.getResult(), 3); + + auto matmulVOut = b.create(loc, shape, dtype); + auto matmulV = b.create( + /*location=*/loc, matmulVOut.getResult().getType(), + /*inputs=*/ValueRange{softmax->getResult(0), value}, + /*outputs=*/ValueRange{matmulVOut.getResult()}); + return SmallVector{matmulV.getResults()[0]}; +} + /////// Operations corresponding to library calls defined with Tablegen //////// #define GET_OP_CLASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index f60c8cec2..6cb633459 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_library(GCPasses OneDNNGraphToLinalg.cpp Pipeline.cpp TileNamed.cpp + FlashAttentionConversion.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include diff --git a/lib/gc/Transforms/FlashAttentionConversion.cpp b/lib/gc/Transforms/FlashAttentionConversion.cpp new file mode 100644 index 000000000..e00186f9c --- /dev/null +++ b/lib/gc/Transforms/FlashAttentionConversion.cpp @@ -0,0 +1,390 @@ +//===-- FlashAttentionConversion.cpp ----------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Passes.h" + +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_FLASHATTENTIONCONVERSION +#include "gc/Transforms/Passes.h.inc" + +namespace { + +struct FlashAttentionConfig { + int RowBlockSize, ColumnBlockSize; +}; + +static FlashAttentionConfig +getDefaultFlashAttentionConfig(linalgx::ScaledDotProductAttentionOp &sdpaOp) { + // TODO: allow tuning + FlashAttentionConfig cfg; + cfg.RowBlockSize = 64; + cfg.ColumnBlockSize = 64; + return cfg; +} + +static LogicalResult verifyAndAppend(SmallVector &decomposedOps, + Value curVal) { + return success(); +} + +struct MHAToFlashAttention + : public OpRewritePattern { + using OpRewritePattern< + linalgx::ScaledDotProductAttentionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(linalgx::ScaledDotProductAttentionOp sdpaOp, + PatternRewriter &rewriter) const override { + FlashAttentionConfig cfg = getDefaultFlashAttentionConfig(sdpaOp); + Location loc = sdpaOp.getLoc(); + OpBuilder::InsertionGuard guard(rewriter); + auto shape = + dyn_cast(sdpaOp.getOperand(0).getType()).getShape(); + auto dtype = dyn_cast(sdpaOp.getOperand(0).getType()) + .getElementType(); + int64_t seqLen = shape[2], headDim = shape[3]; + auto Q = sdpaOp.getOperand(0), K = sdpaOp.getOperand(1), + V = sdpaOp.getOperand(2), mask = sdpaOp.getOperand(3); + // construct 3 parallel outermost loops for + // batchSize/numHeads/(seqLen/rowBlockSize) + SmallVector destinationTensors; + tensor::getOrCreateDestinations(rewriter, sdpaOp.getLoc(), sdpaOp, + destinationTensors); + SmallVector lbs, ubs, tileSizes; + for (size_t i = 0; i < 3; ++i) { + lbs.push_back(getAsIndexOpFoldResult(rewriter.getContext(), 0)); + ubs.push_back(getAsIndexOpFoldResult(rewriter.getContext(), shape[i])); + tileSizes.push_back(getAsIndexOpFoldResult( + rewriter.getContext(), i == 2 ? 
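+          // The seqLen dimension (i == 2) is tiled by RowBlockSize, while
+          // batch and numHeads are tiled by 1, so each forall iteration
+          // owns one RowBlockSize x headDim row block of Q and of the
+          // output.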
cfg.RowBlockSize : 1)); + } + // create forall loop + auto forallOp = rewriter.create( + loc, lbs, ubs, tileSizes, destinationTensors, + /*mapping=*/std::nullopt, + /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); + rewriter.setInsertionPointToEnd(forallOp.getBody()); + SmallVector ivs = forallOp.getInductionVars(); + // inserting body for forall loop + SmallVector offsets; + offsets.push_back(getAsOpFoldResult(ivs[0])); + offsets.push_back(getAsOpFoldResult(ivs[1])); + offsets.push_back(getAsOpFoldResult(ivs[2])); + offsets.push_back(rewriter.getIndexAttr(0)); + SmallVector sizes(4, rewriter.getIndexAttr(1)); + sizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize); + sizes[3] = rewriter.getIndexAttr(headDim); + SmallVector strides(4, rewriter.getIndexAttr(1)); + SmallVector QSliceShape{1, cfg.RowBlockSize, headDim}; + SmallVector KVSliceShape{1, cfg.ColumnBlockSize, headDim}; + Value QSliceShapeOut = + rewriter.create(loc, QSliceShape, dtype); + Value KVSliceShapeOut = + rewriter.create(loc, KVSliceShape, dtype); + Value QSlice = rewriter.create( + loc, cast(QSliceShapeOut.getType()), Q, offsets, + sizes, strides); + Value OSlice = rewriter.create( + loc, destinationTensors[0], offsets, sizes, strides); + Value collapsedOSlice = rewriter.create( + loc, OSlice, SmallVector{{0, 1, 2}, {3}}); + SmallVector blockShape(1, cfg.RowBlockSize); + Value maxSlice = rewriter.create(loc, blockShape, dtype); + Value sumSlice = rewriter.create(loc, blockShape, dtype); + Value zero = + rewriter.create(loc, rewriter.getZeroAttr(dtype)); + Value minusInf = rewriter.create( + loc, rewriter.getFloatAttr( + dtype, APFloat::getLargest( + cast(dtype).getFloatSemantics(), true))); + Value maxSliceFilled = + rewriter.create(loc, minusInf, maxSlice).getResult(0); + Value sumSliceFilled = + rewriter.create(loc, zero, sumSlice).getResult(0); + Value collapsedOSliceFilled = + rewriter.create(loc, zero, collapsedOSlice) + .getResult(0); + // create the innermost for loop for columnBlock + SmallVector innermostDestinationTensors{ + collapsedOSliceFilled, maxSliceFilled, sumSliceFilled}; + auto columnBlockLoop = rewriter.create( + loc, + getValueOrCreateConstantIndexOp( + rewriter, loc, getAsIndexOpFoldResult(rewriter.getContext(), 0UL)), + getValueOrCreateConstantIndexOp( + rewriter, loc, + getAsIndexOpFoldResult(rewriter.getContext(), seqLen)), + getValueOrCreateConstantIndexOp( + rewriter, loc, + getAsIndexOpFoldResult(rewriter.getContext(), cfg.ColumnBlockSize)), + innermostDestinationTensors, + [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, + ValueRange /*iterArgs*/) {}); + ivs.push_back(columnBlockLoop.getInductionVar()); + rewriter.setInsertionPointToStart(columnBlockLoop.getBody()); + // innermost computations + Value prevOSlice = columnBlockLoop.getRegionIterArgs()[0], + prevMaxSlice = columnBlockLoop.getRegionIterArgs()[1], + prevSumSlice = columnBlockLoop.getRegionIterArgs()[2]; + // adjust offsets and sizes + offsets[2] = getAsOpFoldResult(ivs[3]); + sizes[2] = rewriter.getIndexAttr(cfg.ColumnBlockSize); + Value KSlice = rewriter.create( + loc, cast(KVSliceShapeOut.getType()), K, offsets, + sizes, strides); + Value VSlice = rewriter.create( + loc, cast(KVSliceShapeOut.getType()), V, offsets, + sizes, strides); + offsets[2] = getAsOpFoldResult(ivs[2]); + offsets[3] = getAsOpFoldResult(ivs[3]); + sizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize); + sizes[3] = rewriter.getIndexAttr(cfg.ColumnBlockSize); + SmallVector maskSliceShape{cfg.RowBlockSize, cfg.ColumnBlockSize}; + Value QKShapeOut = 
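+    // Block shapes per inner iteration: the Q slice is
+    // <1 x RowBlockSize x headDim>, the K/V slices are
+    // <1 x ColumnBlockSize x headDim>, and both the mask slice and the
+    // S = Q @ K^T block computed below are
+    // <RowBlockSize x ColumnBlockSize>.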
+ rewriter.create(loc, maskSliceShape, dtype); + Value maskSlice = rewriter.create( + loc, cast(QKShapeOut.getType()), mask, offsets, sizes, + strides); + // transpose K + SmallVector transposedShape{1, headDim, cfg.RowBlockSize}; + Value transposedShapeOut = + rewriter.create(loc, transposedShape, dtype); + SmallVector transPerm{0, 2, 1}; + Value transposedKSlice = rewriter + .create( + loc, KSlice, transposedShapeOut, transPerm) + ->getResult(0); + // matmul QK + Value matmulQKOutFilled = + rewriter.create(loc, zero, QKShapeOut).getResult(0); + Value matmulQK = rewriter + .create( + loc, matmulQKOutFilled.getType(), + ValueRange{QSlice, transposedKSlice}, + ValueRange{matmulQKOutFilled}) + .getResult(0); + // scale & add mask + float rsqrtHead = 1 / sqrt(headDim); + SmallVector indexingMaps; + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(2)); + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(2)); + Value mul = + rewriter + .create( + loc, QKShapeOut.getType(), ValueRange{matmulQK}, + ValueRange{QKShapeOut}, indexingMaps, + SmallVector(2, + utils::IteratorType::parallel), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value constant = nestedBuilder.create( + loc, nestedBuilder.getFloatAttr(dtype, rsqrtHead)); + Value scaled = nestedBuilder.create( + loc, args[0], constant); + nestedBuilder.create(nestedLoc, scaled); + }) + .getResult(0); + Value add = rewriter + .create(loc, QKShapeOut.getType(), + ValueRange{mul, maskSlice}, + ValueRange{QKShapeOut}) + .getResult(0); + // tiling softmax + SmallVector reducedShape{cfg.RowBlockSize}; + Value reducedShapeOut = + rewriter.create(loc, reducedShape, dtype); + Value reduceMaxFilled = + rewriter.create(loc, minusInf, reducedShapeOut) + .getResult(0); + Value curMaxSlice = + rewriter + .create( + loc, ValueRange{add}, ValueRange{reduceMaxFilled}, 1, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value result = nestedBuilder.create( + nestedLoc, blockArgs[0], blockArgs[1]); + nestedBuilder.create(nestedLoc, result); + }) + .getResult(0); + Value newMaxSlice = + rewriter + .create(loc, reducedShapeOut.getType(), + ValueRange{prevMaxSlice, curMaxSlice}, + ValueRange{reducedShapeOut}) + .getResult(0); + Value newMaxSliceBroadcasted = + rewriter + .create(loc, newMaxSlice, QKShapeOut, + SmallVector{1}) + .getResults()[0]; + Value sub = + rewriter + .create(loc, QKShapeOut.getType(), + ValueRange{add, newMaxSliceBroadcasted}, + ValueRange{QKShapeOut}) + .getResult(0); + Value PSlice = + rewriter + .create(loc, QKShapeOut.getType(), ValueRange{sub}, + ValueRange{QKShapeOut}) + .getResult(0); + Value reduceSumFilled = + rewriter.create(loc, zero, reducedShapeOut) + .getResult(0); + Value curSumSlice = + rewriter + .create( + loc, ValueRange{PSlice}, ValueRange{reduceSumFilled}, 1, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value result = nestedBuilder.create( + nestedLoc, blockArgs[0], blockArgs[1]); + nestedBuilder.create(nestedLoc, result); + }) + .getResult(0); + Value maxDiff = + rewriter + .create(loc, reducedShapeOut.getType(), + ValueRange{prevMaxSlice, newMaxSlice}, + ValueRange{reducedShapeOut}) + .getResult(0); + Value expMaxDiff = rewriter + .create( + loc, reducedShapeOut.getType(), + ValueRange{maxDiff}, ValueRange{reducedShapeOut}) + .getResult(0); + Value rescaledPrevSumSlice = + rewriter + .create(loc, reducedShapeOut.getType(), + ValueRange{prevSumSlice, expMaxDiff}, + ValueRange{reducedShapeOut}) + .getResult(0); + Value 
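+    // Online-softmax (flash attention) update assembled by the surrounding
+    // ops, with m = running row max, l = running row sum, and S the current
+    // RowBlockSize x ColumnBlockSize score block:
+    //   m_new = max(m_prev, rowmax(S))
+    //   P     = exp(S - m_new)
+    //   l_new = rowsum(P) + l_prev * exp(m_prev - m_new)
+    //   O_new = O_prev * (l_prev * exp(m_prev - m_new) / l_new)
+    //           + (P @ V) / l_new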
newSumSlice = rewriter + .create( + loc, reducedShapeOut.getType(), + ValueRange{curSumSlice, rescaledPrevSumSlice}, + ValueRange{reducedShapeOut}) + .getResult(0); + Value newSumSliceRecip = + rewriter + .create(loc, reducedShapeOut.getType(), + ValueRange{newSumSlice}, + ValueRange{reducedShapeOut}) + .getResult(0); + SmallVector VShape{cfg.RowBlockSize, headDim}; + Value VShapeOut = rewriter.create(loc, VShape, dtype); + Value matmulVOutFilled = + rewriter.create(loc, zero, VShapeOut).getResult(0); + SmallVector expandedPSliceShape{ + rewriter.getIndexAttr(1), rewriter.getIndexAttr(cfg.RowBlockSize), + rewriter.getIndexAttr(cfg.ColumnBlockSize)}; + Value expandedPSliceShapeOut = + rewriter.create(loc, expandedPSliceShape, dtype); + Value expandedPSlice = rewriter.create( + loc, expandedPSliceShapeOut.getType(), PSlice, + SmallVector{{0, 1}, {2}}, expandedPSliceShape); + Value matmulV = rewriter + .create( + loc, matmulVOutFilled.getType(), + ValueRange{expandedPSlice, VSlice}, + ValueRange{matmulVOutFilled}) + .getResult(0); + Value newSumSliceRecipBroadcasted = + rewriter + .create(loc, newSumSliceRecip, VShapeOut, + SmallVector{1}) + .getResults()[0]; + Value rescaledPrevSumSliceBroadcasted = + rewriter + .create(loc, rescaledPrevSumSlice, VShapeOut, + SmallVector{1}) + .getResults()[0]; + Value rescaledMatmulV = + rewriter + .create( + loc, matmulVOutFilled.getType(), + ValueRange{matmulV, newSumSliceRecipBroadcasted}, + ValueRange{matmulVOutFilled}) + .getResult(0); + Value sumSliceQuotient = + rewriter + .create(loc, matmulVOutFilled.getType(), + ValueRange{rescaledPrevSumSliceBroadcasted, + newSumSliceRecipBroadcasted}, + ValueRange{matmulVOutFilled}) + .getResult(0); + Value rescaledOSlice = + rewriter + .create(loc, matmulVOutFilled.getType(), + ValueRange{prevOSlice, sumSliceQuotient}, + ValueRange{matmulVOutFilled}) + .getResult(0); + Value newOSlice = + rewriter + .create(loc, VShapeOut.getType(), + ValueRange{rescaledOSlice, rescaledMatmulV}, + ValueRange{VShapeOut}) + .getResult(0); + // yield all the results of the innermost loop. 
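+    // The yield order must match the iter_args order above
+    // (O block, running max, running sum), so the next column-block
+    // iteration reads the updated values through getRegionIterArgs()[0..2].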
+ rewriter.create( + loc, ValueRange{newOSlice, newMaxSlice, newSumSlice}); + // yield parallel loop results + auto innermostLoopResults = columnBlockLoop->getResults(); + Value OSliceFinal = innermostLoopResults[0]; + SmallVector outputOffsets; + outputOffsets.push_back(getAsOpFoldResult(ivs[0])); + outputOffsets.push_back(getAsOpFoldResult(ivs[1])); + outputOffsets.push_back(getAsOpFoldResult(ivs[2])); + outputOffsets.push_back(rewriter.getIndexAttr(0)); + SmallVector outputSizes(4, rewriter.getIndexAttr(1)); + outputSizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize); + outputSizes[3] = rewriter.getIndexAttr(headDim); + // Add the scf.forall.in_parallel operations for the forall op + rewriter.setInsertionPointToEnd(forallOp.getBody()); + auto term = rewriter.create(loc); + rewriter.setInsertionPointToStart(term.getBody()); + rewriter.create( + loc, OSliceFinal, forallOp.getRegionIterArgs()[0], outputOffsets, + outputSizes, strides); + rewriter.replaceOp(sdpaOp, forallOp->getResults()); + return success(); + } +}; + +struct FlashAttentionConversion + : public impl::FlashAttentionConversionBase { +public: + void runOnOperation() final { + auto &ctx = getContext(); + IRRewriter rewriter(&ctx); + RewritePatternSet patterns(&ctx); + patterns.add(patterns.getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 29a143835..9061da4f0 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -106,6 +106,7 @@ void populateCPURuntimePasses(mlir::OpPassManager &pm) { } void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) { + pm.addPass(createLowerAffinePass()); pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addPass(createConvertSCFToCFPass()); pm.addPass(cpuruntime::createCPURuntimeToLLVM()); diff --git a/test/gc/Transform/flashAttention.mlir b/test/gc/Transform/flashAttention.mlir new file mode 100644 index 000000000..284ab68d2 --- /dev/null +++ b/test/gc/Transform/flashAttention.mlir @@ -0,0 +1,31 @@ +// RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void +// | FileCheck --allow-empty + +func.func @flash_attention(%arg0: tensor<56x16x384x64xf32>, %arg1: tensor<56x16x384x64xf32>, %arg2: tensor<56x16x384x64xf32>, %arg3: tensor<56x16x384x384xf32>) -> tensor<56x16x384x64xf32> { + %0 = tensor.empty() : tensor<56x16x384x64xf32> + %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x384xf32>) outs(%0 : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32> + return %1 : tensor<56x16x384x64xf32> +} + +func.func @main() { + %cst = arith.constant 4.000000e+00 : f32 + + %QKVShape = tensor.empty() : tensor<56x16x384x64xf32> + %maskShape = tensor.empty() : tensor<56x16x384x384xf32> + + %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32> + %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32> + %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32> + %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<56x16x384x384xf32>) -> tensor<56x16x384x384xf32> + + %out = func.call @flash_attention(%Q, %K, %V, 
%mask) :
+         (tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x384xf32>)
+         -> (tensor<56x16x384x64xf32>)
+
+  %idx = arith.constant 0 : index
+  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<56x16x384x64xf32>
+  cpuruntime.printf "output[0, 0, 0, 0]: %f\n" %val : f32
+
+  return
+}
+// CHECK: output[0, 0, 0, 0]: 4.0
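+// Expected value: with Q, K, V and the mask all filled with 4.0, every row
+// of the scaled and masked score matrix is constant, so the softmax weights
+// are uniform (1/384 each) and every output element is
+// sum over 384 positions of (1/384) * 4.0 = 4.0, matching the CHECK line.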