From 5cba884e9f66608bea6a19e38fd298e101fd0214 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Wed, 14 Feb 2024 22:24:24 +0000 Subject: [PATCH 1/2] [mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and properties instead. --- .../mlir/Dialect/SparseTensor/IR/Enums.h | 516 +++++++----------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 10 +- .../SparseTensor/IR/Detail/LvlTypeParser.cpp | 4 +- .../SparseTensor/IR/SparseTensorDialect.cpp | 16 +- .../Transforms/SparseTensorRewriting.cpp | 2 +- .../Transforms/Utils/SparseTensorLevel.cpp | 6 +- .../lib/Dialect/SparseTensor/Utils/Merger.cpp | 3 +- .../Dialect/SparseTensor/MergerTest.cpp | 34 +- 8 files changed, 237 insertions(+), 354 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index 74cc0dee554a1..079899a147476 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -153,45 +153,9 @@ enum class Action : uint32_t { kSortCOOInPlace = 8, }; -/// This enum defines all the sparse representations supportable by -/// the SparseTensor dialect. We use a lightweight encoding to encode -/// the "format" per se (dense, compressed, singleton, loose_compressed, -/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when -/// the format is NOutOfM. -/// The encoding is chosen for performance of the runtime library, and thus may -/// change in future versions; consequently, client code should use the -/// predicate functions defined below, rather than relying on knowledge -/// about the particular binary encoding. -/// -/// The `Undef` "format" is a special value used internally for cases -/// where we need to store an undefined or indeterminate `LevelType`. -/// It should not be used externally, since it does not indicate an -/// actual/representable format. -/// -/// Bit manipulations for LevelType: -/// -/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty | -/// -enum class LevelType : uint64_t { - Undef = 0x000000000000, - Dense = 0x000000010000, - Compressed = 0x000000020000, - CompressedNu = 0x000000020001, - CompressedNo = 0x000000020002, - CompressedNuNo = 0x000000020003, - Singleton = 0x000000040000, - SingletonNu = 0x000000040001, - SingletonNo = 0x000000040002, - SingletonNuNo = 0x000000040003, - LooseCompressed = 0x000000080000, - LooseCompressedNu = 0x000000080001, - LooseCompressedNo = 0x000000080002, - LooseCompressedNuNo = 0x000000080003, - NOutOfM = 0x000000100000, -}; - /// This enum defines all supported storage format without the level properties. enum class LevelFormat : uint64_t { + Undef = 0x00000000, Dense = 0x00010000, Compressed = 0x00020000, Singleton = 0x00040000, @@ -199,328 +163,236 @@ enum class LevelFormat : uint64_t { NOutOfM = 0x00100000, }; +/// Returns string representation of the given level format. +constexpr const char *toFormatString(LevelFormat lvlFmt) { + switch (lvlFmt) { + case LevelFormat::Undef: + return "undef"; + case LevelFormat::Dense: + return "dense"; + case LevelFormat::Compressed: + return "compressed"; + case LevelFormat::Singleton: + return "singleton"; + case LevelFormat::LooseCompressed: + return "loose_compressed"; + case LevelFormat::NOutOfM: + return "structured"; + } + return ""; +} + /// This enum defines all the nondefault properties for storage formats. -enum class LevelPropertyNondefault : uint64_t { +enum class LevelPropNonDefault : uint64_t { Nonunique = 0x0001, Nonordered = 0x0002, }; -/// Get N of NOutOfM level type. -constexpr uint64_t getN(LevelType lt) { - return (static_cast(lt) >> 32) & 0xff; +/// Returns string representation of the given level properties. +constexpr const char *toPropString(LevelPropNonDefault lvlProp) { + switch (lvlProp) { + case LevelPropNonDefault::Nonunique: + return "nonunique"; + case LevelPropNonDefault::Nonordered: + return "nonordered"; + } + return ""; } -/// Get M of NOutOfM level type. -constexpr uint64_t getM(LevelType lt) { - return (static_cast(lt) >> 40) & 0xff; -} +/// This enum defines all the sparse representations supportable by +/// the SparseTensor dialect. We use a lightweight encoding to encode +/// the "format" per se (dense, compressed, singleton, loose_compressed, +/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when +/// the format is NOutOfM. +/// The encoding is chosen for performance of the runtime library, and thus may +/// change in future versions; consequently, client code should use the +/// predicate functions defined below, rather than relying on knowledge +/// about the particular binary encoding. +/// +/// The `Undef` "format" is a special value used internally for cases +/// where we need to store an undefined or indeterminate `LevelType`. +/// It should not be used externally, since it does not indicate an +/// actual/representable format. -/// Convert N of NOutOfM level type to the stored bits. -constexpr uint64_t nToBits(uint64_t n) { return n << 32; } +struct LevelType { +public: + /// Check that the `LevelType` contains a valid (possibly undefined) value. + static constexpr bool isValidLvlBits(uint64_t lvlBits) { + const uint64_t formatBits = lvlBits & 0xffff0000; + const uint64_t propertyBits = lvlBits & 0xffff; + // If undefined/dense/NOutOfM, then must be unique and ordered. + // Otherwise, the format must be one of the known ones. + return (formatBits <= 0x10000 || formatBits == 0x100000) + ? (propertyBits == 0) + : (formatBits == 0x20000 || formatBits == 0x40000 || + formatBits == 0x80000); + } -/// Convert M of NOutOfM level type to the stored bits. -constexpr uint64_t mToBits(uint64_t m) { return m << 40; } + /// Convert a LevelFormat to its corresponding LevelType with the given + /// properties. Returns std::nullopt when the properties are not applicable + /// for the input level format. + static std::optional + buildLvlType(LevelFormat lf, + const std::vector &properties, + uint64_t n = 0, uint64_t m = 0) { + uint64_t newN = n << 32; + uint64_t newM = m << 40; + uint64_t ltBits = static_cast(lf) | newN | newM; + for (auto p : properties) + ltBits |= static_cast(p); + + return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits)) + : std::nullopt; + } + static std::optional buildLvlType(LevelFormat lf, bool ordered, + bool unique, uint64_t n = 0, + uint64_t m = 0) { + std::vector properties; + if (!ordered) + properties.push_back(LevelPropNonDefault::Nonordered); + if (!unique) + properties.push_back(LevelPropNonDefault::Nonunique); + return buildLvlType(lf, properties, n, m); + } -/// Check if the `LevelType` is NOutOfM (regardless of -/// properties and block sizes). -constexpr bool isNOutOfMLT(LevelType lt) { - return ((static_cast(lt) & 0x100000) == - static_cast(LevelType::NOutOfM)); -} + /// Explicit conversion from uint64_t. + constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) { + assert(isValidLvlBits(bits)); + }; -/// Check if the `LevelType` is NOutOfM with the correct block sizes. -constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { - return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m; -} + /// Constructs a LevelType with the given format using all default properties. + /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast(f)) { + assert(isValidLvlBits(lvlBits) && !isa()); + }; -/// Returns string representation of the given dimension level type. -constexpr const char *toMLIRString(LevelType lvlType) { - auto lt = static_cast(static_cast(lvlType) & 0xffffffff); - switch (lt) { - case LevelType::Undef: - return "undef"; - case LevelType::Dense: - return "dense"; - case LevelType::Compressed: - return "compressed"; - case LevelType::CompressedNu: - return "compressed(nonunique)"; - case LevelType::CompressedNo: - return "compressed(nonordered)"; - case LevelType::CompressedNuNo: - return "compressed(nonunique, nonordered)"; - case LevelType::Singleton: - return "singleton"; - case LevelType::SingletonNu: - return "singleton(nonunique)"; - case LevelType::SingletonNo: - return "singleton(nonordered)"; - case LevelType::SingletonNuNo: - return "singleton(nonunique, nonordered)"; - case LevelType::LooseCompressed: - return "loose_compressed"; - case LevelType::LooseCompressedNu: - return "loose_compressed(nonunique)"; - case LevelType::LooseCompressedNo: - return "loose_compressed(nonordered)"; - case LevelType::LooseCompressedNuNo: - return "loose_compressed(nonunique, nonordered)"; - case LevelType::NOutOfM: - return "structured"; - } - return ""; -} + /// Converts to uint64_t + explicit operator uint64_t() const { return lvlBits; } -/// Check that the `LevelType` contains a valid (possibly undefined) value. -constexpr bool isValidLT(LevelType lt) { - const uint64_t formatBits = static_cast(lt) & 0xffff0000; - const uint64_t propertyBits = static_cast(lt) & 0xffff; - // If undefined/dense/NOutOfM, then must be unique and ordered. - // Otherwise, the format must be one of the known ones. - return (formatBits <= 0x10000 || formatBits == 0x100000) - ? (propertyBits == 0) - : (formatBits == 0x20000 || formatBits == 0x40000 || - formatBits == 0x80000); -} + bool operator==(const LevelType lhs) const { + return static_cast(lhs) == lvlBits; + } + bool operator!=(const LevelType lhs) const { return !(*this == lhs); } -/// Check if the `LevelType` is the special undefined value. -constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; } + LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); } -/// Check if the `LevelType` is dense (regardless of properties). -constexpr bool isDenseLT(LevelType lt) { - return (static_cast(lt) & ~0xffff) == - static_cast(LevelType::Dense); -} + /// Get N/M of NOutOfM level type. + constexpr uint64_t getN() const { + assert(isa()); + return (lvlBits >> 32) & 0xff; + } + constexpr uint64_t getM() const { + assert(isa()); + return (lvlBits >> 40) & 0xff; + } -/// Check if the `LevelType` is compressed (regardless of properties). -constexpr bool isCompressedLT(LevelType lt) { - return (static_cast(lt) & ~0xffff) == - static_cast(LevelType::Compressed); -} + /// Get the `LevelFormat` of the `LevelType`. + LevelFormat getLvlFmt() const { + return static_cast(lvlBits & 0xffff0000); + } -/// Check if the `LevelType` is singleton (regardless of properties). -constexpr bool isSingletonLT(LevelType lt) { - return (static_cast(lt) & ~0xffff) == - static_cast(LevelType::Singleton); -} + /// Check if the `LevelType` is in the `LevelFormat`. + template + bool isa() const { + return getLvlFmt() == fmt; + } -/// Check if the `LevelType` is loose compressed (regardless of properties). -constexpr bool isLooseCompressedLT(LevelType lt) { - return (static_cast(lt) & ~0xffff) == - static_cast(LevelType::LooseCompressed); -} + /// Check if the `LevelType` has the properties + template + bool isa() const { + return lvlBits & static_cast(p); + } -/// Check if the `LevelType` needs positions array. -constexpr bool isWithPosLT(LevelType lt) { - return isCompressedLT(lt) || isLooseCompressedLT(lt); -} + /// Check if the `LevelType` needs positions array. + bool isWithPosLT() const { + return isa() || + isa(); + } -/// Check if the `LevelType` needs coordinates array. -constexpr bool isWithCrdLT(LevelType lt) { - return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || - isNOutOfMLT(lt); -} + /// Check if the `LevelType` needs coordinates array. + constexpr bool isWithCrdLT() const { + // All sparse levels has coordinate array. + return !isa(); + } -/// Check if the `LevelType` is ordered (regardless of storage format). -constexpr bool isOrderedLT(LevelType lt) { - return !(static_cast(lt) & 2); - return !(static_cast(lt) & 2); -} + std::string toMLIRString() const { + std::string lvlStr = toFormatString(getLvlFmt()); + std::string propStr = ""; + if (isa()) + propStr += toPropString(LevelPropNonDefault::Nonunique); + + if (isa()) { + if (!propStr.empty()) + propStr += ", "; + propStr += toPropString(LevelPropNonDefault::Nonordered); + } + if (!propStr.empty()) + lvlStr += ("(" + propStr + ")"); + return lvlStr; + } -/// Check if the `LevelType` is unique (regardless of storage format). -constexpr bool isUniqueLT(LevelType lt) { - return !(static_cast(lt) & 1); - return !(static_cast(lt) & 1); -} +private: + /// Bit manipulations for LevelType: + /// + /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty | + /// + uint64_t lvlBits; +}; -/// Convert a LevelType to its corresponding LevelFormat. -/// Returns std::nullopt when input lt is Undef. -constexpr std::optional getLevelFormat(LevelType lt) { - if (lt == LevelType::Undef) - return std::nullopt; - return static_cast(static_cast(lt) & 0xffff0000); -} +// For backward-compatibility. TODO: remove below after fully migration. +constexpr uint64_t nToBits(uint64_t n) { return n << 32; } +constexpr uint64_t mToBits(uint64_t m) { return m << 40; } -/// Convert a LevelFormat to its corresponding LevelType with the given -/// properties. Returns std::nullopt when the properties are not applicable -/// for the input level format. inline std::optional buildLevelType(LevelFormat lf, - const std::vector &properties, + const std::vector &properties, uint64_t n = 0, uint64_t m = 0) { - uint64_t newN = n << 32; - uint64_t newM = m << 40; - uint64_t ltInt = static_cast(lf) | newN | newM; - for (auto p : properties) { - ltInt |= static_cast(p); - } - auto lt = static_cast(ltInt); - return isValidLT(lt) ? std::optional(lt) : std::nullopt; + return LevelType::buildLvlType(lf, properties, n, m); } - inline std::optional buildLevelType(LevelFormat lf, bool ordered, bool unique, uint64_t n = 0, uint64_t m = 0) { - std::vector properties; - if (!ordered) - properties.push_back(LevelPropertyNondefault::Nonordered); - if (!unique) - properties.push_back(LevelPropertyNondefault::Nonunique); - return buildLevelType(lf, properties, n, m); + return LevelType::buildLvlType(lf, ordered, unique, n, m); +} +inline bool isUndefLT(LevelType lt) { return lt.isa(); } +inline bool isDenseLT(LevelType lt) { return lt.isa(); } +inline bool isCompressedLT(LevelType lt) { + return lt.isa(); +} +inline bool isLooseCompressedLT(LevelType lt) { + return lt.isa(); +} +inline bool isSingletonLT(LevelType lt) { + return lt.isa(); +} +inline bool isNOutOfMLT(LevelType lt) { return lt.isa(); } +inline bool isOrderedLT(LevelType lt) { + return !lt.isa(); } +inline bool isUniqueLT(LevelType lt) { + return !lt.isa(); +} +inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); } +inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); } +inline bool isValidLT(LevelType lt) { + return LevelType::isValidLvlBits(static_cast(lt)); +} +inline std::optional getLevelFormat(LevelType lt) { + LevelFormat fmt = lt.getLvlFmt(); + if (fmt == LevelFormat::Undef) + return std::nullopt; + return fmt; +} +inline uint64_t getN(LevelType lt) { return lt.getN(); } +inline uint64_t getM(LevelType lt) { return lt.getM(); } +inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { + return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m; +} +inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); } // // Ensure the above methods work as intended. // -static_assert( - (getLevelFormat(LevelType::Undef) == std::nullopt && - *getLevelFormat(LevelType::Dense) == LevelFormat::Dense && - *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed && - *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed && - *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed && - *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed && - *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton && - *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton && - *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton && - *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton && - *getLevelFormat(LevelType::LooseCompressed) == - LevelFormat::LooseCompressed && - *getLevelFormat(LevelType::LooseCompressedNu) == - LevelFormat::LooseCompressed && - *getLevelFormat(LevelType::LooseCompressedNo) == - LevelFormat::LooseCompressed && - *getLevelFormat(LevelType::LooseCompressedNuNo) == - LevelFormat::LooseCompressed && - *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM), - "getLevelFormat conversion is broken"); - -static_assert( - (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) && - isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) && - isValidLT(LevelType::CompressedNo) && - isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) && - isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) && - isValidLT(LevelType::SingletonNuNo) && - isValidLT(LevelType::LooseCompressed) && - isValidLT(LevelType::LooseCompressedNu) && - isValidLT(LevelType::LooseCompressedNo) && - isValidLT(LevelType::LooseCompressedNuNo) && - isValidLT(LevelType::NOutOfM)), - "isValidLT definition is broken"); - -static_assert((isDenseLT(LevelType::Dense) && - !isDenseLT(LevelType::Compressed) && - !isDenseLT(LevelType::CompressedNu) && - !isDenseLT(LevelType::CompressedNo) && - !isDenseLT(LevelType::CompressedNuNo) && - !isDenseLT(LevelType::Singleton) && - !isDenseLT(LevelType::SingletonNu) && - !isDenseLT(LevelType::SingletonNo) && - !isDenseLT(LevelType::SingletonNuNo) && - !isDenseLT(LevelType::LooseCompressed) && - !isDenseLT(LevelType::LooseCompressedNu) && - !isDenseLT(LevelType::LooseCompressedNo) && - !isDenseLT(LevelType::LooseCompressedNuNo) && - !isDenseLT(LevelType::NOutOfM)), - "isDenseLT definition is broken"); - -static_assert((!isCompressedLT(LevelType::Dense) && - isCompressedLT(LevelType::Compressed) && - isCompressedLT(LevelType::CompressedNu) && - isCompressedLT(LevelType::CompressedNo) && - isCompressedLT(LevelType::CompressedNuNo) && - !isCompressedLT(LevelType::Singleton) && - !isCompressedLT(LevelType::SingletonNu) && - !isCompressedLT(LevelType::SingletonNo) && - !isCompressedLT(LevelType::SingletonNuNo) && - !isCompressedLT(LevelType::LooseCompressed) && - !isCompressedLT(LevelType::LooseCompressedNu) && - !isCompressedLT(LevelType::LooseCompressedNo) && - !isCompressedLT(LevelType::LooseCompressedNuNo) && - !isCompressedLT(LevelType::NOutOfM)), - "isCompressedLT definition is broken"); - -static_assert((!isSingletonLT(LevelType::Dense) && - !isSingletonLT(LevelType::Compressed) && - !isSingletonLT(LevelType::CompressedNu) && - !isSingletonLT(LevelType::CompressedNo) && - !isSingletonLT(LevelType::CompressedNuNo) && - isSingletonLT(LevelType::Singleton) && - isSingletonLT(LevelType::SingletonNu) && - isSingletonLT(LevelType::SingletonNo) && - isSingletonLT(LevelType::SingletonNuNo) && - !isSingletonLT(LevelType::LooseCompressed) && - !isSingletonLT(LevelType::LooseCompressedNu) && - !isSingletonLT(LevelType::LooseCompressedNo) && - !isSingletonLT(LevelType::LooseCompressedNuNo) && - !isSingletonLT(LevelType::NOutOfM)), - "isSingletonLT definition is broken"); - -static_assert((!isLooseCompressedLT(LevelType::Dense) && - !isLooseCompressedLT(LevelType::Compressed) && - !isLooseCompressedLT(LevelType::CompressedNu) && - !isLooseCompressedLT(LevelType::CompressedNo) && - !isLooseCompressedLT(LevelType::CompressedNuNo) && - !isLooseCompressedLT(LevelType::Singleton) && - !isLooseCompressedLT(LevelType::SingletonNu) && - !isLooseCompressedLT(LevelType::SingletonNo) && - !isLooseCompressedLT(LevelType::SingletonNuNo) && - isLooseCompressedLT(LevelType::LooseCompressed) && - isLooseCompressedLT(LevelType::LooseCompressedNu) && - isLooseCompressedLT(LevelType::LooseCompressedNo) && - isLooseCompressedLT(LevelType::LooseCompressedNuNo) && - !isLooseCompressedLT(LevelType::NOutOfM)), - "isLooseCompressedLT definition is broken"); - -static_assert((!isNOutOfMLT(LevelType::Dense) && - !isNOutOfMLT(LevelType::Compressed) && - !isNOutOfMLT(LevelType::CompressedNu) && - !isNOutOfMLT(LevelType::CompressedNo) && - !isNOutOfMLT(LevelType::CompressedNuNo) && - !isNOutOfMLT(LevelType::Singleton) && - !isNOutOfMLT(LevelType::SingletonNu) && - !isNOutOfMLT(LevelType::SingletonNo) && - !isNOutOfMLT(LevelType::SingletonNuNo) && - !isNOutOfMLT(LevelType::LooseCompressed) && - !isNOutOfMLT(LevelType::LooseCompressedNu) && - !isNOutOfMLT(LevelType::LooseCompressedNo) && - !isNOutOfMLT(LevelType::LooseCompressedNuNo) && - isNOutOfMLT(LevelType::NOutOfM)), - "isNOutOfMLT definition is broken"); - -static_assert((isOrderedLT(LevelType::Dense) && - isOrderedLT(LevelType::Compressed) && - isOrderedLT(LevelType::CompressedNu) && - !isOrderedLT(LevelType::CompressedNo) && - !isOrderedLT(LevelType::CompressedNuNo) && - isOrderedLT(LevelType::Singleton) && - isOrderedLT(LevelType::SingletonNu) && - !isOrderedLT(LevelType::SingletonNo) && - !isOrderedLT(LevelType::SingletonNuNo) && - isOrderedLT(LevelType::LooseCompressed) && - isOrderedLT(LevelType::LooseCompressedNu) && - !isOrderedLT(LevelType::LooseCompressedNo) && - !isOrderedLT(LevelType::LooseCompressedNuNo) && - isOrderedLT(LevelType::NOutOfM)), - "isOrderedLT definition is broken"); - -static_assert((isUniqueLT(LevelType::Dense) && - isUniqueLT(LevelType::Compressed) && - !isUniqueLT(LevelType::CompressedNu) && - isUniqueLT(LevelType::CompressedNo) && - !isUniqueLT(LevelType::CompressedNuNo) && - isUniqueLT(LevelType::Singleton) && - !isUniqueLT(LevelType::SingletonNu) && - isUniqueLT(LevelType::SingletonNo) && - !isUniqueLT(LevelType::SingletonNuNo) && - isUniqueLT(LevelType::LooseCompressed) && - !isUniqueLT(LevelType::LooseCompressedNu) && - isUniqueLT(LevelType::LooseCompressedNo) && - !isUniqueLT(LevelType::LooseCompressedNuNo) && - isUniqueLT(LevelType::NOutOfM)), - "isUniqueLT definition is broken"); - /// Bit manipulations for affine encoding. /// /// Note that because the indices in the mappings refer to dimensions diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 55af8becbba20..3ae06f220c528 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -34,9 +34,9 @@ static_assert( "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == - static_cast(LevelPropertyNondefault::Nonordered) && + static_cast(LevelPropNonDefault::Nonordered) && static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == - static_cast(LevelPropertyNondefault::Nonunique), + static_cast(LevelPropNonDefault::Nonunique), "MlirSparseTensorLevelProperty (C-API) and " "LevelPropertyNondefault (C++) mismatch"); @@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { LevelType lt = static_cast(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); - return static_cast(*getLevelFormat(lt)); + return static_cast(lt.getLvlFmt()); } int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { @@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( const enum MlirSparseTensorLevelPropertyNondefault *properties, unsigned size, unsigned n, unsigned m) { - std::vector props; + std::vector props; for (unsigned i = 0; i < size; i++) - props.push_back(static_cast(properties[i])); + props.push_back(static_cast(properties[i])); return static_cast( *buildLevelType(static_cast(lvlFmt), props, n, m)); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp index 0fb0d2761054b..380cccc989ec6 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp @@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser, ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)), "expected valid level property (e.g. nonordered, nonunique or high)") if (strVal.compare("nonunique") == 0) { - *properties |= static_cast(LevelPropertyNondefault::Nonunique); + *properties |= static_cast(LevelPropNonDefault::Nonunique); } else if (strVal.compare("nonordered") == 0) { - *properties |= static_cast(LevelPropertyNondefault::Nonordered); + *properties |= static_cast(LevelPropNonDefault::Nonordered); } else { parser.emitError(loc, "unknown level property: ") << strVal; return failure(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index aed43f26d54f1..6d02645d860e9 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -35,6 +35,14 @@ using namespace mlir; using namespace mlir::sparse_tensor; +// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as +// well. +namespace mlir::sparse_tensor { +llvm::hash_code hash_value(LevelType lt) { + return llvm::hash_value(static_cast(lt)); +} +} // namespace mlir::sparse_tensor + //===----------------------------------------------------------------------===// // Local Convenience Methods. //===----------------------------------------------------------------------===// @@ -83,11 +91,11 @@ void StorageLayout::foreachField( } // The values array. if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, - LevelType::Undef))) + LevelFormat::Undef))) return; // Put metadata at the end. if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, - LevelType::Undef))) + LevelFormat::Undef))) return; } @@ -341,7 +349,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const { LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { if (!getImpl()) - return LevelType::Dense; + return LevelFormat::Dense; assert(l < getLvlRank() && "Level is out of bounds"); return getLvlTypes()[l]; } @@ -975,7 +983,7 @@ static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { SmallVector lts; for (auto lt : enc.getLvlTypes()) - lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true)); + lts.push_back(lt.stripProperties()); return SparseTensorEncodingAttr::get( enc.getContext(), lts, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 235c5453f9cc9..7326a6a381128 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -46,7 +46,7 @@ static bool isZeroValue(Value val) { static bool isSparseTensor(Value v) { auto enc = getSparseTensorEncoding(v.getType()); return enc && !llvm::all_of(enc.getLvlTypes(), - [](auto lt) { return lt == LevelType::Dense; }); + [](auto lt) { return lt == LevelFormat::Dense; }); } static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp index c85f8204ba752..61a3703b73bf0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp @@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel { class DenseLevel : public SparseTensorLevel { public: DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded) - : SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize), + : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize), encoded(encoded) {} Value peekCrdAt(OpBuilder &, Location, Value pos) const override { @@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, Value sz = stt.hasEncoding() ? b.create(l, t, lvl).getResult() : b.create(l, t, lvl).getResult(); - switch (*getLevelFormat(lt)) { + switch (lt.getLvlFmt()) { case LevelFormat::Dense: return std::make_unique(tid, lvl, sz, stt.hasEncoding()); case LevelFormat::Compressed: { @@ -1296,6 +1296,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, Value crd = genToCoordinates(b, l, t, lvl); return std::make_unique(tid, lvl, lt, sz, crd); } + case LevelFormat::Undef: + llvm_unreachable("undefined level format"); } llvm_unreachable("unrecognizable level format"); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 96537cbb0c483..731cd79a1e3b4 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -226,7 +226,8 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops, syntheticTensor(numInputOutputTensors), numTensors(numInputOutputTensors + 1), numLoops(numLoops), hasSparseOut(false), - lvlTypes(numTensors, std::vector(numLoops, LevelType::Undef)), + lvlTypes(numTensors, + std::vector(numLoops, LevelFormat::Undef)), loopToLvl(numTensors, std::vector>(numLoops, std::nullopt)), lvlToLoop(numTensors, diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index ce9c0e39b31b9..62a19c084cac0 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase { MergerTest3T1L() : MergerTestBase(3, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(2)); // Tensor 0: sparse input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); // Tensor 1: sparse input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed); // Tensor 2: dense output vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense); + merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense); } }; @@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase { MergerTest4T1L() : MergerTestBase(4, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(3)); // Tensor 0: sparse input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); // Tensor 1: sparse input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed); // Tensor 2: sparse input vector - merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed); // Tensor 3: dense output vector - merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense); + merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense); } }; @@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase { MergerTest3T1LD() : MergerTestBase(3, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(2)); // Tensor 0: sparse input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); // Tensor 1: dense input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense); + merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense); // Tensor 2: dense output vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense); + merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense); } }; @@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase { MergerTest4T1LU() : MergerTestBase(4, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(3)); // Tensor 0: undef input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef); + merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef); // Tensor 1: dense input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense); + merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense); // Tensor 2: undef input vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef); + merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef); // Tensor 3: dense output vector. - merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense); + merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense); } }; @@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase { EXPECT_TRUE(merger.getSynTensorID() == tid(3)); merger.setHasSparseOut(true); // Tensor 0: undef input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef); + merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef); // Tensor 1: undef input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef); + merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef); // Tensor 2: sparse output vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed); + merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed); } }; From 0f92b30f92a81331600d2d9cf99d83cde31f3f63 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Feb 2024 20:07:06 +0000 Subject: [PATCH 2/2] address comments --- .../mlir/Dialect/SparseTensor/IR/Enums.h | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index 079899a147476..a20a7906189d0 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -163,6 +163,11 @@ enum class LevelFormat : uint64_t { NOutOfM = 0x00100000, }; +template +constexpr bool isAnyOfFmt(LevelFormat fmt) { + return (... || (targets == fmt)); +} + /// Returns string representation of the given level format. constexpr const char *toFormatString(LevelFormat lvlFmt) { switch (lvlFmt) { @@ -218,14 +223,15 @@ struct LevelType { public: /// Check that the `LevelType` contains a valid (possibly undefined) value. static constexpr bool isValidLvlBits(uint64_t lvlBits) { - const uint64_t formatBits = lvlBits & 0xffff0000; + auto fmt = static_cast(lvlBits & 0xffff0000); const uint64_t propertyBits = lvlBits & 0xffff; // If undefined/dense/NOutOfM, then must be unique and ordered. // Otherwise, the format must be one of the known ones. - return (formatBits <= 0x10000 || formatBits == 0x100000) + return (isAnyOfFmt(fmt)) ? (propertyBits == 0) - : (formatBits == 0x20000 || formatBits == 0x40000 || - formatBits == 0x80000); + : (isAnyOfFmt(fmt)); } /// Convert a LevelFormat to its corresponding LevelType with the given @@ -235,6 +241,7 @@ struct LevelType { buildLvlType(LevelFormat lf, const std::vector &properties, uint64_t n = 0, uint64_t m = 0) { + assert((n & 0xff) == n && (m & 0xff) == m); uint64_t newN = n << 32; uint64_t newM = m << 40; uint64_t ltBits = static_cast(lf) | newN | newM; @@ -275,11 +282,13 @@ struct LevelType { LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); } - /// Get N/M of NOutOfM level type. + /// Get N of NOutOfM level type. constexpr uint64_t getN() const { assert(isa()); return (lvlBits >> 32) & 0xff; } + + /// Get M of NOutOfM level type. constexpr uint64_t getM() const { assert(isa()); return (lvlBits >> 40) & 0xff; @@ -389,10 +398,6 @@ inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { } inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); } -// -// Ensure the above methods work as intended. -// - /// Bit manipulations for affine encoding. /// /// Note that because the indices in the mappings refer to dimensions