Skip to content

Commit 0049c9c

Browse files
committed
init pass
1 parent 1eaf75d commit 0049c9c

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

include/gc/Dialect/Arith/Utils/EasyBuild.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ inline EBFloatPoint operator-(const EBFloatPoint &a) {
264264
}
265265

266266
#define DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \
267-
EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \
267+
inline EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \
268268
return OperatorHandlers::handleCmp<OPCLASS>(a, b, PRED); \
269269
} \
270270
template <typename T> EBUnsigned operator OP(const TYPE &a, T b) { \

include/gc/IR/EasyBuildSCF.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ inline int IfIterator::operator*() const {
150150

151151
} // namespace impl
152152

153-
impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
153+
inline impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
154154
return impl::IfSimulator{s.builder, op};
155155
}
156156

lib/gc/Transforms/FlashAttentionConversion.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
109
#include "./Tiling.hpp"
1110
#include "gc/Dialect/Arith/Utils/EasyBuild.h"
1211
#include "gc/Dialect/Linalgx/LinalgxOps.h"
@@ -45,11 +44,44 @@ namespace gc {
4544
#include "gc/Transforms/Passes.h.inc"
4645

4746
namespace {
47+
48+
struct MHAToFlashAttention
49+
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
50+
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
51+
52+
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
53+
PatternRewriter &rewriter) const override {
54+
if (!llvm::isa<linalgx::ScaledDotProductAttentionOp>(linalgOp))
55+
return failure();
56+
if (linalgOp.hasPureBufferSemantics())
57+
return failure();
58+
}
59+
};
60+
4861
struct FlashAttentionConversion
4962
: public impl::FlashAttentionConversionBase<FlashAttentionConversion> {
5063
public:
5164
void runOnOperation() final {
52-
return;
65+
auto &ctx = getContext();
66+
IRRewriter rewriter(&ctx);
67+
RewritePatternSet patterns(&ctx);
68+
69+
patterns.add<MHAToFlashAttention>(patterns.getContext());
70+
// linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
71+
// linalg::ControlDropUnitDims options;
72+
// options.rankReductionStrategy =
73+
// linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
74+
// linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
75+
// tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
76+
77+
// for (auto *dialect : ctx.getLoadedDialects())
78+
// dialect->getCanonicalizationPatterns(patterns);
79+
// for (RegisteredOperationName op : ctx.getRegisteredOperations())
80+
// op.getCanonicalizationPatterns(patterns, &ctx);
81+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
82+
std::move(patterns)))) {
83+
return signalPassFailure();
84+
}
5385
}
5486
};
5587

unittests/Example/Example.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
10-
#include "gtest/gtest.h"
10+
// #include "gtest/gtest.h"
1111

12-
TEST(example, HelloWorld) {
13-
ASSERT_EQ(mlir::linalgx::LinalgxDialect::getDialectNamespace(), "linalgx");
14-
}
12+
// TEST(example, HelloWorld) {
13+
// ASSERT_EQ(mlir::linalgx::LinalgxDialect::getDialectNamespace(), "linalgx");
14+
// }

0 commit comments

Comments
 (0)