Skip to content

Commit 97f85b4

Browse files
committed
fix flash attention decomposition
1 parent 9b02c96 commit 97f85b4

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

lib/gc/Dialect/Linalgx/LinalgxOps.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -620,12 +620,9 @@ void MultiBatchMatmulOp::getEffects(
620620

621621
// NOTE(review): verification is currently a no-op — operand shapes
// (query/key/value/mask) and rank/dtype consistency are not checked.
// Confirm whether real verification is intended here; the decomposition
// below assumes a rank-4 query tensor.
LogicalResult ScaledDotProductAttentionOp::verify() { return success(); }
622622

623-
/// Given an N-dimensional tensor x, this method converts
624-
/// softmax(x) to the following sequence of operations:
625-
///
626-
/// 1. transpose ins[1]
627-
/// 2. matmul ins[0] @ 1
628-
///
623+
/// This method converts ScaledDotProductAttention into the following
624+
/// sequence of operations:
625+
/// output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
629626
FailureOr<SmallVector<Value>>
630627
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
631628
OpBuilder::InsertionGuard guard(b);
@@ -635,6 +632,7 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
635632
mask = getInputs()[3];
636633
auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
637634
auto shape = cast<RankedTensorType>(query.getType()).getShape();
635+
float rsqrt_head = 1 / sqrt(shape[3]);
638636

639637
SmallVector<int64_t> permutation{0, 1, 3, 2};
640638
SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
@@ -652,16 +650,40 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
652650
/*inputs=*/ValueRange{query, transpose->getResult(0)},
653651
/*outputs=*/ValueRange{matmulQKOut.getResult()});
654652

653+
auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
654+
// Scale the raw QK scores elementwise by 1/sqrt(head_dim).
655+
SmallVector<AffineMap, 4> indexingMaps;
656+
indexingMaps.push_back(b.getMultiDimIdentityMap(4));
657+
indexingMaps.push_back(b.getMultiDimIdentityMap(4));
658+
auto mul = b.create<linalg::GenericOp>(
659+
/*location=*/loc, matmulQKOut.getResult().getType(),
660+
/*inputs=*/ValueRange{matmulQK->getResult(0)},
661+
/*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps,
662+
SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel),
663+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
664+
Value constant = b.create<arith::ConstantOp>(
665+
loc, nestedBuilder.getFloatAttr(dtype, rsqrt_head));
666+
Value added =
667+
nestedBuilder.create<arith::MulFOp>(loc, args[0], constant);
668+
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
669+
});
670+
655671
auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
656672
auto add = b.create<linalg::AddOp>(
657673
/*location=*/loc, addOut.getResult().getType(),
658-
/*inputs=*/ValueRange{matmulQK->getResult(0), mask},
674+
/*inputs=*/ValueRange{mul->getResult(0), mask},
659675
/*outputs=*/ValueRange{addOut.getResult()});
660676

677+
auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
678+
auto softmax = b.create<linalg::SoftmaxOp>(
679+
/*location=*/loc, softmaxOut.getResult().getType(),
680+
/*inputs=*/add->getResult(0),
681+
/*outputs=*/softmaxOut.getResult(), 3);
682+
661683
auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
662684
auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
663685
/*location=*/loc, matmulVOut.getResult().getType(),
664-
/*inputs=*/ValueRange{add->getResult(0), value},
686+
/*inputs=*/ValueRange{softmax->getResult(0), value},
665687
/*outputs=*/ValueRange{matmulVOut.getResult()});
666688
return SmallVector<Value>{matmulV.getResults()[0]};
667689
}

0 commit comments

Comments
 (0)