Skip to content

Commit f75d46a

Browse files
authored
[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (#65621)
This patch adds support for lowering vector.outerproduct to the ArmSME MOPA intrinsic for the following types:

  vector<[8]xf16>, vector<[8]xf16> -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32>
  vector<[2]xf64>, vector<[2]xf64> -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA (non-widening) [2].

Note that at the ISA level these variants are implemented by different architecture features; these are listed below:

  FMOPA (non-widening)
    * half-precision - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision - +sme2p1,+b16b16

There's currently no way to target different features when lowering to ArmSME.

Integration tests are added for F32 and F64. We use QEMU to run the integration tests, but SME2 support isn't available yet (it's targeted for 9.0), so integration tests for these variants are excluded.

Masking is currently unsupported.

Depends on #65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
1 parent 293ae0b commit f75d46a

File tree

7 files changed

+418
-8
lines changed

7 files changed

+418
-8
lines changed

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
namespace mlir {
2121
namespace arm_sme {
2222

23+
constexpr unsigned MinStreamingVectorLengthInBits = 128;
24+
2325
/// Return minimum number of elements for the given element `type` in
2426
/// a vector of SVL bits.
2527
unsigned getSMETileSliceMinNumElts(Type type);

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering
361361
}
362362
};
363363

364+
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
365+
///
366+
/// Example:
367+
///
368+
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
369+
/// : vector<[4]xf32>, vector<[4]xf32>
370+
///
371+
/// is converted to:
372+
///
373+
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
374+
/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
375+
/// vector<[4]xf32>) -> ()
376+
///
377+
/// Currently only supports FMOPA and BFMOPA (non-widening).
378+
struct VectorOuterProductToArmSMELowering
379+
: public ConvertOpToLLVMPattern<vector::OuterProductOp> {
380+
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
381+
382+
LogicalResult
383+
matchAndRewrite(vector::OuterProductOp outerProductOp,
384+
vector::OuterProductOp::Adaptor adaptor,
385+
ConversionPatternRewriter &rewriter) const override {
386+
auto isSupportedType = [](VectorType vectorType) {
387+
// TODO: the FP outer product instruction variants are predicated on
388+
// different features [1]:
389+
//
390+
// * FMOPA (non-widening)
391+
// * half-precision - +sme2p1,+sme-f16f16
392+
// * single-precision - +sme
393+
// * double-precision - +sme-f64f64
394+
// * BFMOPA
395+
// * half-precision - +sme2p1,+b16b16
396+
//
397+
// It should be possible to control lowering based on target features.
398+
// [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
399+
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
400+
return false;
401+
402+
auto elementType = vectorType.getElementType();
403+
404+
if (!elementType.isF16() && !elementType.isBF16() &&
405+
!elementType.isF32() && !elementType.isF64())
406+
return false;
407+
408+
unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
409+
vectorType.getElementTypeBitWidth();
410+
if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
411+
return false;
412+
413+
return true;
414+
};
415+
416+
auto resultVectorType = outerProductOp.getResultVectorType();
417+
if (!isSupportedType(resultVectorType))
418+
return outerProductOp.emitError("unsupported type");
419+
420+
vector::CombiningKind kind = outerProductOp.getKind();
421+
if (kind != vector::CombiningKind::ADD)
422+
// TODO: support subtract.
423+
return outerProductOp.emitError("unsupported kind");
424+
425+
auto maskableOp =
426+
cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
427+
if (maskableOp.isMasked())
428+
// TODO: support masking.
429+
return outerProductOp.emitError("masking is currently unsupported");
430+
431+
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
432+
// AXPY operation not suited for SME.
433+
return failure();
434+
435+
auto loc = outerProductOp.getLoc();
436+
437+
Value acc = outerProductOp.getAcc();
438+
if (!acc)
439+
// Initialize accumulator with zero.
440+
acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
441+
442+
unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
443+
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
444+
loc, rewriter.getIntegerType(elementWidth), acc);
445+
446+
// Create all active predicate mask.
447+
auto one = rewriter.create<arith::ConstantOp>(
448+
loc, rewriter.getI1Type(),
449+
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
450+
auto predTy =
451+
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
452+
/*scalableDims=*/{true});
453+
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
454+
455+
auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
456+
457+
// Create 'arm_sme.intr.mopa' outer product intrinsic.
458+
rewriter.create<arm_sme::aarch64_sme_mopa>(
459+
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
460+
outerProductOp.getRhs());
461+
462+
// Create `CastTileToVectorOp` to use as the output.
463+
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
464+
outerProductOp, resultVectorType, tileId);
465+
466+
return success();
467+
}
468+
};
469+
364470
} // namespace
365471

366472
void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
374480
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
375481
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
376482
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
377-
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
483+
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
484+
arm_sme::aarch64_sme_za_disable>();
378485
target.addLegalOp<GetTileID>();
486+
target.addIllegalOp<vector::OuterProductOp>();
379487

380488
// Mark 'func.func' ops as legal if either:
381489
// 1. no 'arm_za' function attribute is present.
@@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
405513
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
406514
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
407515
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
408-
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
409-
LoadTileSliceToArmSMELowering,
410-
MoveVectorToTileSliceToArmSMELowering>(converter);
516+
patterns
517+
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
518+
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
519+
VectorOuterProductToArmSMELowering>(converter);
411520
}

mlir/lib/Dialect/ArmSME/Utils/Utils.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
using namespace mlir;
1818
using namespace mlir::arm_sme;
1919

20-
static constexpr unsigned MinStreamingVectorLengthInBits = 128;
21-
2220
unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
2321
assert(isValidSMETileElementType(type) && "invalid tile type!");
2422
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1122,11 +1122,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
11221122

11231123
LogicalResult matchAndRewrite(vector::OuterProductOp op,
11241124
PatternRewriter &rewriter) const override {
1125+
VectorType resType = op.getResultVectorType();
1126+
if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1127+
return failure();
1128+
11251129
auto loc = op.getLoc();
11261130

11271131
VectorType lhsType = op.getOperandVectorTypeLHS();
11281132
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1129-
VectorType resType = op.getResultVectorType();
11301133
Type eltType = resType.getElementType();
11311134
bool isInt = isa<IntegerType, IndexType>(eltType);
11321135
Value acc = op.getAcc();

mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// vector.transfer_write
5+
//===----------------------------------------------------------------------===//
26

37
// CHECK-LABEL: @transfer_write_2d_zero_i8(
48
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
3337
return
3438
}
3539

40+
//===----------------------------------------------------------------------===//
41+
// vector.load
42+
//===----------------------------------------------------------------------===//
43+
3644
// -----
3745

3846
// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
232240
return %tile : vector<[1]x[1]xi128>
233241
}
234242

243+
//===----------------------------------------------------------------------===//
244+
// vector.store
245+
//===----------------------------------------------------------------------===//
246+
235247
// -----
236248

237249
// CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
391403
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
392404
return
393405
}
406+
407+
//===----------------------------------------------------------------------===//
408+
// vector.outerproduct
409+
//===----------------------------------------------------------------------===//
410+
411+
// -----
412+
413+
// CHECK-LABEL: @vector_outerproduct_add_f16
414+
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
415+
func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
416+
// CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
417+
// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
418+
// CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
419+
// CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
420+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
421+
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
422+
}
423+
424+
// -----
425+
426+
// CHECK-LABEL: @vector_outerproduct_add_bf16
427+
func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
428+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
429+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
430+
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
431+
}
432+
433+
// -----
434+
435+
// CHECK-LABEL: @vector_outerproduct_add_f32
436+
func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
437+
// CHECK-NOT: arith.extui
438+
// CHECK-NOT: arith.trunci
439+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
440+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
441+
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
442+
}
443+
444+
// -----
445+
446+
// CHECK-LABEL: @vector_outerproduct_add_f64
447+
func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
448+
// CHECK: arith.trunci {{.*}} : i64 to i32
449+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
450+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
451+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
452+
}
453+
454+
// -----
455+
456+
// CHECK-LABEL: @vector_outerproduct_no_accumulator
457+
func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
458+
// CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
459+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
460+
%0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
461+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
462+
}
463+
464+
// -----
465+
466+
// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
467+
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
468+
// CHECK-NOT: arm_sme
469+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
470+
return %0 : vector<[2]xf64>
471+
}
472+
473+
// -----
474+
475+
func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
476+
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
477+
// expected-error@+1 {{unsupported type}}
478+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
479+
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
480+
}
481+
482+
// -----
483+
484+
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
485+
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
486+
// expected-error@+1 {{unsupported kind}}
487+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
488+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
489+
}
490+
491+
// -----
492+
493+
func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
494+
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
495+
// expected-error@+1 {{masking is currently unsupported}}
496+
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
497+
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
498+
}

0 commit comments

Comments
 (0)