diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..7226642daf861 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSVE) +add_public_tablegen_target(MLIRArmSVEPassIncGen) + +add_mlir_doc(Passes ArmSVEPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h new file mode 100644 index 0000000000000..66f30a67cb05b --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h @@ -0,0 +1,36 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::arm_sve { + +#define GEN_PASS_DECL +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc" + +/// Pass to legalize Arm SVE vector storage. +std::unique_ptr createLegalizeVectorStoragePass(); + +/// Collect a set of patterns to legalize Arm SVE vector storage. +void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc" + +} // namespace mlir::arm_sve + +#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td new file mode 100644 index 0000000000000..d7cb309db5253 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td @@ -0,0 +1,68 @@ +//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def LegalizeVectorStorage + : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> { + let summary = "Ensures stores of SVE vector types will be legal"; + let description = [{ + This pass ensures that loads, stores, and allocations of SVE vector types + will be legal in the LLVM backend. 
It does this at the memref level, so this pass must be applied before lowering all the way to LLVM.
+
+    This pass currently addresses two issues:
+
+    ## Loading and storing predicate types
+
+    It is only legal to load/store predicate types equal to (or greater than)
+    a full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller
+    predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full
+    predicate type (referred to as an `svbool`) before and after storing and
+    loading respectively. This pass does so by widening allocations and
+    inserting conversion intrinsics. Note that non-power-of-two masks (e.g.
+    `vector<[7]xi1>`), which are not SVE predicates, are ignored.
+
+    For example:
+
+    ```mlir
+    %alloca = memref.alloca() : memref<vector<[4]xi1>>
+    %mask = vector.constant_mask [4] : vector<[4]xi1>
+    memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+    %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+    ```
+    Becomes:
+    ```mlir
+    %alloca = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+    %mask = vector.constant_mask [4] : vector<[4]xi1>
+    %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+    memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
+    %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
+    %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
+    ```
+
+    ## Relax alignments for SVE vector allocas
+
+    The storage for SVE vector types only needs an alignment that matches the
+    element type (for example, 4-byte alignment for `f32`s). However, the LLVM
+    backend currently defaults to aligning to `base size` x `element size`
+    bytes. For a non-legal vector type like `vector<[8]xf32>` this results in
+    8 x 4 = 32-byte alignment, but the backend only supports up to 16-byte
+    alignment for SVE vectors on the stack. Explicitly setting a smaller
+    alignment prevents this issue.
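+
+    For example (a sketch mirroring this pass's tests; the exact printed form
+    of the alignment attribute may differ):
+
+    ```mlir
+    %alloca = memref.alloca() : memref<vector<[8]xf32>>
+    ```
+    Becomes:
+    ```mlir
+    %alloca = memref.alloca() {alignment = 16 : i64} : memref<vector<[8]xf32>>
+    ```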
+ }]; + let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()"; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", "vector::VectorDialect", + "arm_sve::ArmSVEDialect"]; +} + +#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 5489a13a8040b..7301905954f56 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -19,6 +19,7 @@ #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" #include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/Bufferization/Pipelines/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -82,6 +83,7 @@ inline void registerAllPasses() { transform::registerTransformPasses(); vector::registerVectorPasses(); arm_sme::registerArmSMEPasses(); + arm_sve::registerArmSVEPasses(); // Dialect pipelines bufferization::registerBufferizationPipelines(); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt index 2f1c43fae240d..a70c489a51fea 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -1,8 +1,10 @@ add_mlir_dialect_library(MLIRArmSVETransforms LegalizeForLLVMExport.cpp + LegalizeVectorStorage.cpp DEPENDS MLIRArmSVEConversionsIncGen + MLIRArmSVEPassIncGen LINK_LIBS PUBLIC MLIRArmSVEDialect diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp new file mode 100644 index 0000000000000..bee1f3659753b --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -0,0 +1,338 @@ +//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::arm_sve { +#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc" +} // namespace mlir::arm_sve + +using namespace mlir; +using namespace mlir::arm_sve; + +// A tag to mark unrealized_conversions produced by this pass. This is used to +// detect IR this pass failed to completely legalize, and report an error. +// If everything was successfully legalized, no tagged ops will remain after +// this pass. +constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__"); + +/// Definitions: +/// +/// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE +/// predicate registers. A full predicate is the smallest quantity that can be +/// loaded/stored. +/// +/// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing +/// dimension matches the size of a legal SVE vector size (such as +/// vector<[4]xi1>), but is too small to be stored to memory (i.e smaller than +/// a svbool). 
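+///
+/// For example (illustrative): vector<[4]xi1> and vector<2x[8]xi1> are SVE
+/// masks [2] and must be widened to an svbool before they can be stored,
+/// whereas vector<[16]xi1> and vector<2x[16]xi1> are already svbools [1] and
+/// can be loaded/stored directly.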
+
+namespace {
+
+/// Checks if a vector type is an SVE mask [2].
+bool isSVEMaskType(VectorType type) {
+  return type.getRank() > 0 && type.getElementType().isInteger(1) &&
+         type.getScalableDims().back() && type.getShape().back() < 16 &&
+         llvm::isPowerOf2_32(type.getShape().back()) &&
+         !llvm::is_contained(type.getScalableDims().drop_back(), true);
+}
+
+VectorType widenScalableMaskTypeToSvbool(VectorType type) {
+  assert(isSVEMaskType(type));
+  return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
+}
+
+/// A helper for cloning an op and replacing it with a new version, updated by
+/// a callback.
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
+                              TLegalizerCallback callback) {
+  // Clone the previous op to preserve any properties/attributes.
+  auto newOp = op.clone();
+  rewriter.insert(newOp);
+  rewriter.replaceOp(op, callback(newOp));
+}
+
+/// A helper for cloning an op and replacing it with a new version, updated by
+/// a callback, and an unrealized conversion back to the type of the replaced
+/// op.
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
+                                       TLegalizerCallback callback) {
+  replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
+    // Mark our `unrealized_conversion_cast`s with a pass label.
+    return rewriter.create<UnrealizedConversionCastOp>(
+        op.getLoc(), TypeRange{op.getResult().getType()},
+        ValueRange{callback(newOp)},
+        NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag),
+                       rewriter.getUnitAttr()));
+  });
+}
+
+/// Extracts the widened SVE memref value (that's legal to store/load) from the
+/// `unrealized_conversion_cast`s added by this pass.
+static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
+  Operation *definingOp = illegalMemref.getDefiningOp();
+  if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag))
+    return failure();
+  auto unrealizedConversion =
+      llvm::cast<UnrealizedConversionCastOp>(definingOp);
+  return unrealizedConversion.getOperand(0);
+}
+
+/// The default alignment of an alloca in LLVM may request overaligned sizes
+/// for SVE types, which will fail during stack frame allocation. This rewrite
+/// explicitly adds a reasonable alignment to allocas of scalable types.
+struct RelaxScalableVectorAllocaAlignment
+    : public OpRewritePattern<memref::AllocaOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
+                                PatternRewriter &rewriter) const override {
+    auto memrefElementType = allocaOp.getType().getElementType();
+    auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
+    if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
+      return failure();
+
+    // Set the alignment based on the defaults for SVE vectors and predicates.
+    unsigned alignment = vectorType.getElementType().isInteger(1) ? 2 : 16;
+    allocaOp.setAlignment(alignment);
+
+    return success();
+  }
+};
+
+/// Replaces allocations of SVE predicates smaller than an svbool [1]
+/// (_illegal_ to load/store) with a wider allocation of svbool (_legal_ to
+/// load/store) followed by a tagged unrealized conversion to the original
+/// type.
+/// +/// Example +/// ``` +/// %alloca = memref.alloca() : memref> +/// ``` +/// is rewritten into: +/// ``` +/// %widened = memref.alloca() {alignment = 1 : i64} : memref> +/// %alloca = builtin.unrealized_conversion_cast %widened +/// : memref> to memref> +/// {__arm_sve_legalize_vector_storage__} +/// ``` +template +struct LegalizeSVEMaskAllocation : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp, + PatternRewriter &rewriter) const override { + auto vectorType = + llvm::dyn_cast(allocLikeOp.getType().getElementType()); + + if (!vectorType || !isSVEMaskType(vectorType)) + return failure(); + + // Replace this alloc-like op of an SVE mask [2] with one of a (storable) + // svbool mask [1]. A temporary unrealized_conversion_cast is added to the + // old type to allow local rewrites. + replaceOpWithUnrealizedConversion( + rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) { + newAllocLikeOp.getResult().setType( + llvm::cast(newAllocLikeOp.getType().cloneWith( + {}, widenScalableMaskTypeToSvbool(vectorType)))); + return newAllocLikeOp; + }); + + return success(); + } +}; + +/// Replaces vector.type_casts of unrealized conversions to SVE predicate memref +/// types that are _illegal_ to load/store from (!= svbool [1]), with type casts +/// of memref types that are _legal_ to load/store, followed by unrealized +/// conversions. +/// +/// Example: +/// ``` +/// %alloca = builtin.unrealized_conversion_cast %widened +/// : memref> to memref> +/// {__arm_sve_legalize_vector_storage__} +/// %cast = vector.type_cast %alloca +/// : memref> to memref<3xvector<[8]xi1>> +/// ``` +/// is rewritten into: +/// ``` +/// %widened_cast = vector.type_cast %widened +/// : memref> to memref<3xvector<[16]xi1>> +/// %cast = builtin.unrealized_conversion_cast %widened_cast +/// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>> +/// {__arm_sve_legalize_vector_storage__} +/// ``` +struct LegalizeSVEMaskTypeCastConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp, + PatternRewriter &rewriter) const override { + auto resultType = typeCastOp.getResultMemRefType(); + auto vectorType = llvm::dyn_cast(resultType.getElementType()); + + if (!vectorType || !isSVEMaskType(vectorType)) + return failure(); + + auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref()); + if (failed(legalMemref)) + return failure(); + + // Replace this vector.type_cast with one of a (storable) svbool mask [1]. + replaceOpWithUnrealizedConversion( + rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) { + newTypeCast.setOperand(*legalMemref); + newTypeCast.getResult().setType( + llvm::cast(newTypeCast.getType().cloneWith( + {}, widenScalableMaskTypeToSvbool(vectorType)))); + return newTypeCast; + }); + + return success(); + } +}; + +/// Replaces stores to unrealized conversions to SVE predicate memref types that +/// are _illegal_ to load/store from (!= svbool [1]), with +/// `arm_sve.convert_to_svbool`s followed by (legal) wider stores. 
+/// +/// Example: +/// ``` +/// memref.store %mask, %alloca[] : memref> +/// ``` +/// is rewritten into: +/// ``` +/// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1> +/// memref.store %svbool, %widened[] : memref> +/// ``` +struct LegalizeSVEMaskStoreConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + + Value valueToStore = storeOp.getValueToStore(); + auto vectorType = llvm::dyn_cast(valueToStore.getType()); + + if (!vectorType || !isSVEMaskType(vectorType)) + return failure(); + + auto legalMemref = getSVELegalizedMemref(storeOp.getMemref()); + if (failed(legalMemref)) + return failure(); + + auto legalMaskType = widenScalableMaskTypeToSvbool( + llvm::cast(valueToStore.getType())); + auto convertToSvbool = rewriter.create( + loc, legalMaskType, valueToStore); + // Replace this store with a conversion to a storable svbool mask [1], + // followed by a wider store. + replaceOpWithLegalizedOp(rewriter, storeOp, + [&](memref::StoreOp newStoreOp) { + newStoreOp.setOperand(0, convertToSvbool); + newStoreOp.setOperand(1, *legalMemref); + return newStoreOp; + }); + + return success(); + } +}; + +/// Replaces loads from unrealized conversions to SVE predicate memref types +/// that are _illegal_ to load/store from (!= svbool [1]), types with (legal) +/// wider loads, followed by `arm_sve.convert_from_svbool`s. +/// +/// Example: +/// ``` +/// %reload = memref.load %alloca[] : memref> +/// ``` +/// is rewritten into: +/// ``` +/// %svbool = memref.load %widened[] : memref> +/// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1> +/// ``` +struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + + Value loadedMask = loadOp.getResult(); + auto vectorType = llvm::dyn_cast(loadedMask.getType()); + + if (!vectorType || !isSVEMaskType(vectorType)) + return failure(); + + auto legalMemref = getSVELegalizedMemref(loadOp.getMemref()); + if (failed(legalMemref)) + return failure(); + + auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType); + // Replace this load with a legal load of an svbool type, followed by a + // conversion back to the original type. 
+    replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
+      newLoadOp.setMemRef(*legalMemref);
+      newLoadOp.getResult().setType(legalMaskType);
+      return rewriter.create<arm_sve::ConvertFromSvboolOp>(
+          loc, loadedMask.getType(), newLoadOp);
+    });
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<RelaxScalableVectorAllocaAlignment,
+               LegalizeSVEMaskAllocation<memref::AllocaOp>,
+               LegalizeSVEMaskAllocation<memref::AllocOp>,
+               LegalizeSVEMaskTypeCastConversion,
+               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
+      patterns.getContext());
+}
+
+namespace {
+struct LegalizeVectorStorage
+    : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateLegalizeVectorStoragePatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      signalPassFailure();
+    }
+    ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [](UnrealizedConversionCastOp unrealizedConversion) {
+          return !unrealizedConversion->hasAttr(kSVELegalizerTag);
+        });
+    // This detects if we failed to completely legalize the IR.
+    if (failed(applyPartialConversion(getOperation(), target, {})))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
+  return std::make_unique<LegalizeVectorStorage>();
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
new file mode 100644
index 0000000000000..9a3df8376f121
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
@@ -0,0 +1,203 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s
+
+/// This tests the basic functionality of the -arm-sve-legalize-vector-storage
+/// pass.
+ +// ----- + +// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1( +// CHECK-SAME: %[[MASK:.*]]: vector<[1]xi1>) +func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> { + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1> + // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref> + // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1> + %reload = memref.load %alloca[] : memref> + // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1> + return %reload : vector<[1]xi1> +} + +// ----- + +// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1( +// CHECK-SAME: %[[MASK:.*]]: vector<[2]xi1>) +func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> { + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1> + // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref> + // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1> + %reload = memref.load %alloca[] : memref> + // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1> + return %reload : vector<[2]xi1> +} + +// ----- + +// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1( +// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>) +func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> { + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1> + // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref> + // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1> + %reload = memref.load %alloca[] : memref> + // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1> + return %reload : vector<[4]xi1> +} + +// ----- + +// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1( +// CHECK-SAME: %[[MASK:.*]]: vector<[8]xi1>) +func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> { + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1> + // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref> + // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1> + %reload = memref.load %alloca[] : memref> + // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1> + return %reload : vector<[8]xi1> +} + +// ----- + +// CHECK-LABEL: @store_and_reload_sve_predicate_nxv16i1( +// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) +func.func @store_and_reload_sve_predicate_nxv16i1(%mask: vector<[16]xi1>) -> vector<[16]xi1> { + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} 
: memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref> + %reload = memref.load %alloca[] : memref> + // CHECK-NEXT: return %[[RELOAD]] : vector<[16]xi1> + return %reload : vector<[16]xi1> +} + +// ----- + +/// This is not a valid SVE mask type, so is ignored by the +// `-arm-sve-legalize-vector-storage` pass. + +// CHECK-LABEL: @store_and_reload_unsupported_type( +// CHECK-SAME: %[[MASK:.*]]: vector<[7]xi1>) +func.func @store_and_reload_unsupported_type(%mask: vector<[7]xi1>) -> vector<[7]xi1> { + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref> + %reload = memref.load %alloca[] : memref> + // CHECK-NEXT: return %[[RELOAD]] : vector<[7]xi1> + return %reload : vector<[7]xi1> +} + +// ----- + +// CHECK-LABEL: @store_2d_mask_and_reload_slice( +// CHECK-SAME: %[[MASK:.*]]: vector<3x[8]xi1>) +func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> { + // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref> + %alloca = memref.alloca() : memref> + // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1> + // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref> + memref.store %mask, %alloca[] : memref> + // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref> to memref<3xvector<[16]xi1>> + %unpack = vector.type_cast %alloca : memref> to memref<3xvector<[8]xi1>> + // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>> + // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1> + %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>> + // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1> + return %slice : vector<[8]xi1> +} + +// ----- + +// CHECK-LABEL: @set_sve_alloca_alignment +func.func @set_sve_alloca_alignment() { + /// This checks the alignment of alloca's of scalable vectors will be + /// something the backend can handle. Currently, the backend sets the + /// alignment of scalable vectors to their base size (i.e. their size at + /// vscale = 1). This works for hardware-sized types, which always get a + /// 16-byte alignment. The problem is larger types e.g. vector<[8]xf32> end up + /// with alignments larger than 16-bytes (e.g. 32-bytes here), which are + /// unsupported. The `-arm-sve-legalize-vector-storage` pass avoids this + /// issue by explicitly setting the alignment to 16-bytes for all scalable + /// vectors. 
+ + // CHECK-COUNT-6: alignment = 16 + %a1 = memref.alloca() : memref> + %a2 = memref.alloca() : memref> + %a3 = memref.alloca() : memref> + %a4 = memref.alloca() : memref> + %a5 = memref.alloca() : memref> + %a6 = memref.alloca() : memref> + + // CHECK-COUNT-6: alignment = 16 + %b1 = memref.alloca() : memref> + %b2 = memref.alloca() : memref> + %b3 = memref.alloca() : memref> + %b4 = memref.alloca() : memref> + %b5 = memref.alloca() : memref> + %b6 = memref.alloca() : memref> + + // CHECK-COUNT-6: alignment = 16 + %c1 = memref.alloca() : memref> + %c2 = memref.alloca() : memref> + %c3 = memref.alloca() : memref> + %c4 = memref.alloca() : memref> + %c5 = memref.alloca() : memref> + %c6 = memref.alloca() : memref> + + // CHECK-COUNT-6: alignment = 16 + %d1 = memref.alloca() : memref> + %d2 = memref.alloca() : memref> + %d3 = memref.alloca() : memref> + %d4 = memref.alloca() : memref> + %d5 = memref.alloca() : memref> + %d6 = memref.alloca() : memref> + + // CHECK-COUNT-6: alignment = 16 + %e1 = memref.alloca() : memref> + %e2 = memref.alloca() : memref> + %e3 = memref.alloca() : memref> + %e4 = memref.alloca() : memref> + %e5 = memref.alloca() : memref> + %e6 = memref.alloca() : memref> + + // CHECK-COUNT-6: alignment = 16 + %f1 = memref.alloca() : memref> + %f2 = memref.alloca() : memref> + %f3 = memref.alloca() : memref> + %f4 = memref.alloca() : memref> + %f5 = memref.alloca() : memref> + %f6 = memref.alloca() : memref> + + "prevent.dce"( + %a1, %a2, %a3, %a4, %a5, %a6, + %b1, %b2, %b3, %b4, %b5, %b6, + %c1, %c2, %c3, %c4, %c5, %c6, + %d1, %d2, %d3, %d4, %d5, %d6, + %e1, %e2, %e3, %e4, %e5, %e6, + %f1, %f2, %f3, %f4, %f5, %f6) + : (memref>, memref>, memref>, memref>, memref>, memref>, + memref>, memref>, memref>, memref>, memref>, memref>, + memref>, memref>, memref>, memref>, memref>, memref>, + memref>, memref>, memref>, memref>, memref>, memref>, + memref>, memref>, memref>, memref>, memref>, memref>, + memref>, memref>, memref>, memref>, memref>, memref>) -> () + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir new file mode 100644 index 0000000000000..c486bf0de5d35 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir @@ -0,0 +1,117 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \ +// RUN: %mcr_aarch64_cmd -e=entry -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +/// This tests basic functionality of arrays of scalable vectors, which in MLIR +/// are vectors with a single trailing scalable dimension. This test requires +/// the -arm-sve-legalize-vector-storage pass to ensure the loads/stores done +/// here are be legal for the LLVM backend. + +func.func @read_and_print_2d_vector(%memref: memref<3x?xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim = memref.dim %memref, %c1 : memref<3x?xf32> + %mask = vector.create_mask %c2, %dim : vector<3x[8]xi1> + %vector = vector.transfer_read %memref[%c0,%c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[8]xf32> + + /// TODO: Support vector.print for arrays of scalable vectors. 
+ %row0 = vector.extract %vector[0] : vector<[8]xf32> from vector<3x[8]xf32> + %row1 = vector.extract %vector[1] : vector<[8]xf32> from vector<3x[8]xf32> + %row2 = vector.extract %vector[2] : vector<[8]xf32> from vector<3x[8]xf32> + + /// Print each of the vectors. + /// vscale is >= 1, so at least 8 elements will be printed. + + vector.print str "read_and_print_2d_vector()" + // CHECK-LABEL: read_and_print_2d_vector() + // CHECK: ( 8, 8, 8, 8, 8, 8, 8, 8 + vector.print %row0 : vector<[8]xf32> + // CHECK: ( 8, 8, 8, 8, 8, 8, 8, 8 + vector.print %row1 : vector<[8]xf32> + /// This last row is all zero due to our mask. + // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0 + vector.print %row2 : vector<[8]xf32> + + return +} + +func.func @print_1x2xVSCALExf32(%vector: vector<1x2x[4]xf32>) { + /// TODO: Support vector.print for arrays of scalable vectors. + %slice0 = vector.extract %vector[0, 1] : vector<[4]xf32> from vector<1x2x[4]xf32> + %slice1 = vector.extract %vector[0, 1] : vector<[4]xf32> from vector<1x2x[4]xf32> + vector.print %slice0 : vector<[4]xf32> + vector.print %slice1 : vector<[4]xf32> + return +} + +func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 2 : index + %cst = arith.constant 0.000000e+00 : f32 + %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32> + %dim_b = memref.dim %b, %c2 : memref<1x2x?xf32> + %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1> + %mask_b = vector.create_mask %c2, %c3, %dim_b : vector<1x2x[4]xi1> + + /// Print each of the vectors. + /// vscale is >= 1, so at least 4 elements will be printed. + + // CHECK-LABEL: Vector A + // CHECK-NEXT: ( 5, 5, 5, 5 + // CHECK-NEXT: ( 5, 5, 5, 5 + vector.print str "\nVector A" + %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32> + func.call @print_1x2xVSCALExf32(%vector_a) : (vector<1x2x[4]xf32>) -> () + + // CHECK-LABEL: Vector B + // CHECK-NEXT: ( 4, 4, 4, 4 + // CHECK-NEXT: ( 4, 4, 4, 4 + vector.print str "\nVector B" + %vector_b = vector.transfer_read %b[%c0, %c0, %c0], %cst, %mask_b {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32> + func.call @print_1x2xVSCALExf32(%vector_b) : (vector<1x2x[4]xf32>) -> () + + // CHECK-LABEL: Sum + // CHECK-NEXT: ( 9, 9, 9, 9 + // CHECK-NEXT: ( 9, 9, 9, 9 + vector.print str "\nSum" + %sum = arith.addf %vector_a, %vector_b : vector<1x2x[4]xf32> + func.call @print_1x2xVSCALExf32(%sum) : (vector<1x2x[4]xf32>) -> () + + return +} + +func.func @entry() { + %vscale = vector.vscale + + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %f32_8 = arith.constant 8.0 : f32 + %f32_5 = arith.constant 5.0 : f32 + %f32_4 = arith.constant 4.0 : f32 + + %test_1_memref_size = arith.muli %vscale, %c8 : index + %test_1_memref = memref.alloca(%test_1_memref_size) : memref<3x?xf32> + + linalg.fill ins(%f32_8 : f32) outs(%test_1_memref :memref<3x?xf32>) + + vector.print str "=> Print and read 2D arrays of scalable vectors:" + func.call @read_and_print_2d_vector(%test_1_memref) : (memref<3x?xf32>) -> () + + vector.print str "\n====================\n" + + %test_2_memref_size = arith.muli %vscale, %c4 : index + %test_2_memref_a = memref.alloca(%test_2_memref_size) : memref<1x2x?xf32> + %test_2_memref_b = memref.alloca(%test_2_memref_size) : memref<1x2x?xf32> + + linalg.fill ins(%f32_5 : f32) outs(%test_2_memref_a :memref<1x2x?xf32>) + linalg.fill ins(%f32_4 : 
f32) outs(%test_2_memref_b :memref<1x2x?xf32>) + + vector.print str "=> Reading and adding two 3D arrays of scalable vectors:" + func.call @add_arrays_of_scalable_vectors( + %test_2_memref_a, %test_2_memref_b) : (memref<1x2x?xf32>, memref<1x2x?xf32>) -> () + + return +}