@@ -620,12 +620,10 @@ void MultiBatchMatmulOp::getEffects(
}

LogicalResult ScaledDotProductAttentionOp::verify() { return success(); }

-/// Given an N-dimensional tensor x, this method converts
-/// softmax(x) to the following sequence of operations:
-///
-/// 1. transpose ins[1]
-/// 2. matmul ins[0] @ 1
-///
+/// This method converts a ScaledDotProductAttention op into the following
+/// sequence of operations:
+///   output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
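+/// Here ins[0]/ins[1]/ins[2]/ins[3] are the query/key/value/mask operands.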
FailureOr<SmallVector<Value>>
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
@@ -635,6 +632,9 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
  mask = getInputs()[3];
  auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
  auto shape = cast<RankedTensorType>(query.getType()).getShape();
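+  // Attention scale factor: 1/sqrt(head_dim), with head_dim = shape[3].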
+  float rsqrt_head = 1.0f / std::sqrt(static_cast<float>(shape[3]));
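+  // Permutation {0, 1, 3, 2} swaps the two innermost dims of the key tensor.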
  SmallVector<int64_t> permutation{0, 1, 3, 2};
  SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
@@ -652,16 +650,43 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
      /*inputs=*/ValueRange{query, transpose->getResult(0)},
      /*outputs=*/ValueRange{matmulQKOut.getResult()});

+  auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
+  // Scale every element of the QK product by 1/sqrt(head_dim).
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.push_back(b.getMultiDimIdentityMap(4));
+  indexingMaps.push_back(b.getMultiDimIdentityMap(4));
+  auto mul = b.create<linalg::GenericOp>(
+      /*location=*/loc, matmulQKOut.getResult().getType(),
+      /*inputs=*/ValueRange{matmulQK->getResult(0)},
+      /*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps,
+      SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value constant = nestedBuilder.create<arith::ConstantOp>(
+            nestedLoc, nestedBuilder.getFloatAttr(dtype, rsqrt_head));
+        Value scaled =
+            nestedBuilder.create<arith::MulFOp>(nestedLoc, args[0], constant);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, scaled);
+      });
+
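+  // Add the attention mask to the scaled scores.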
  auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto add = b.create<linalg::AddOp>(
      /*location=*/loc, addOut.getResult().getType(),
-      /*inputs=*/ValueRange{matmulQK->getResult(0), mask},
+      /*inputs=*/ValueRange{mul->getResult(0), mask},
      /*outputs=*/ValueRange{addOut.getResult()});
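+  // Normalize the masked scores with a softmax over the last (key) dimension.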
+  auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
+  auto softmax = b.create<linalg::SoftmaxOp>(
+      /*location=*/loc, softmaxOut.getResult().getType(),
+      /*inputs=*/add->getResult(0),
+      /*outputs=*/softmaxOut.getResult(), /*dimension=*/3);
+
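+  // Multiply the attention probabilities with the value tensor.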
  auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
  auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
      /*location=*/loc, matmulVOut.getResult().getType(),
-      /*inputs=*/ValueRange{add->getResult(0), value},
+      /*inputs=*/ValueRange{softmax->getResult(0), value},
      /*outputs=*/ValueRange{matmulVOut.getResult()});
  return SmallVector<Value>{matmulV.getResults()[0]};
}