Skip to content

Commit a83b190

Browse files
[AArch64] Refactor implementation of FP8 types (NFC)
* The FP8 scalar type (`__mfp8`) was described as a vector type * The FP8 vector types were described/assumed to have integer element type (the element type ought to be `__mfp8`), * Add support for `m` type specifier (denoting `__mfp8`) in `DecodeTypeFromStr` and create SVE builtin prototypes using the specifier, instead of `int8_t`.
1 parent 7f4414b commit a83b190

File tree

8 files changed

+76
-23
lines changed

8 files changed

+76
-23
lines changed

clang/include/clang/AST/Type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
25182518
bool isFloat32Type() const;
25192519
bool isDoubleType() const;
25202520
bool isBFloat16Type() const;
2521+
bool isMFloat8Type() const;
25212522
bool isFloat128Type() const;
25222523
bool isIbm128Type() const;
25232524
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
@@ -8532,6 +8533,10 @@ inline bool Type::isBFloat16Type() const {
85328533
return isSpecificBuiltinType(BuiltinType::BFloat16);
85338534
}
85348535

8536+
inline bool Type::isMFloat8Type() const {
8537+
return isSpecificBuiltinType(BuiltinType::MFloat8);
8538+
}
8539+
85358540
inline bool Type::isFloat128Type() const {
85368541
return isSpecificBuiltinType(BuiltinType::Float128);
85378542
}

clang/include/clang/Basic/AArch64SVEACLETypes.def

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
// - IsBF true for vector of brain float elements.
5858
//===----------------------------------------------------------------------===//
5959

60+
#ifndef SVE_SCALAR_TYPE
61+
#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
62+
SVE_TYPE(Name, Id, SingletonId)
63+
#endif
64+
6065
#ifndef SVE_VECTOR_TYPE
6166
#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
6267
SVE_TYPE(Name, Id, SingletonId)
@@ -72,6 +77,11 @@
7277
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
7378
#endif
7479

80+
#ifndef SVE_VECTOR_TYPE_MFLOAT
81+
#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
82+
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, false)
83+
#endif
84+
7585
#ifndef SVE_VECTOR_TYPE_FLOAT
7686
#define SVE_VECTOR_TYPE_FLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
7787
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
@@ -125,8 +135,7 @@ SVE_VECTOR_TYPE_FLOAT("__SVFloat64_t", "__SVFloat64_t", SveFloat64, SveFloat64Ty
125135

126136
SVE_VECTOR_TYPE_BFLOAT("__SVBfloat16_t", "__SVBfloat16_t", SveBFloat16, SveBFloat16Ty, 8, 16, 1)
127137

128-
// This is a 8 bits opaque type.
129-
SVE_VECTOR_TYPE_INT("__SVMfloat8_t", "__SVMfloat8_t", SveMFloat8, SveMFloat8Ty, 16, 8, 1, false)
138+
SVE_VECTOR_TYPE_MFLOAT("__SVMfloat8_t", "__SVMfloat8_t", SveMFloat8, SveMFloat8Ty, 16, 8, 1)
130139

131140
//
132141
// x2
@@ -148,7 +157,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x2_t", "svfloat64x2_t", SveFloat64x2, Sv
148157

149158
SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x2_t", "svbfloat16x2_t", SveBFloat16x2, SveBFloat16x2Ty, 8, 16, 2)
150159

151-
SVE_VECTOR_TYPE_INT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2, false)
160+
SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2)
152161

153162
//
154163
// x3
@@ -170,7 +179,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x3_t", "svfloat64x3_t", SveFloat64x3, Sv
170179

171180
SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x3_t", "svbfloat16x3_t", SveBFloat16x3, SveBFloat16x3Ty, 8, 16, 3)
172181

173-
SVE_VECTOR_TYPE_INT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3, false)
182+
SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3)
174183

175184
//
176185
// x4
@@ -192,19 +201,21 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x4_t", "svfloat64x4_t", SveFloat64x4, Sv
192201

193202
SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x4_t", "svbfloat16x4_t", SveBFloat16x4, SveBFloat16x4Ty, 8, 16, 4)
194203

195-
SVE_VECTOR_TYPE_INT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4, false)
204+
SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4)
196205

197206
SVE_PREDICATE_TYPE_ALL("__SVBool_t", "__SVBool_t", SveBool, SveBoolTy, 16, 1)
198207
SVE_PREDICATE_TYPE_ALL("__clang_svboolx2_t", "svboolx2_t", SveBoolx2, SveBoolx2Ty, 16, 2)
199208
SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4Ty, 16, 4)
200209

201210
SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
202211

203-
AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
212+
SVE_SCALAR_TYPE("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 8)
213+
204214
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
205215
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
206216

207217
#undef SVE_VECTOR_TYPE
218+
#undef SVE_VECTOR_TYPE_MFLOAT
208219
#undef SVE_VECTOR_TYPE_BFLOAT
209220
#undef SVE_VECTOR_TYPE_FLOAT
210221
#undef SVE_VECTOR_TYPE_INT
@@ -213,4 +224,5 @@ AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloa
213224
#undef SVE_OPAQUE_TYPE
214225
#undef AARCH64_VECTOR_TYPE_MFLOAT
215226
#undef AARCH64_VECTOR_TYPE
227+
#undef SVE_SCALAR_TYPE
216228
#undef SVE_TYPE

clang/lib/AST/ASTContext.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,6 +2254,11 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
22542254
Width = NumEls * ElBits * NF; \
22552255
Align = NumEls * ElBits; \
22562256
break;
2257+
#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
2258+
case BuiltinType::Id: \
2259+
Width = Bits; \
2260+
Align = Bits; \
2261+
break;
22572262
#include "clang/Basic/AArch64SVEACLETypes.def"
22582263
#define PPC_VECTOR_TYPE(Name, Id, Size) \
22592264
case BuiltinType::Id: \
@@ -4374,15 +4379,18 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
43744379
ElBits, NF) \
43754380
case BuiltinType::Id: \
43764381
return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
4382+
#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
4383+
ElBits, NF) \
4384+
case BuiltinType::Id: \
4385+
return {MFloat8Ty, llvm::ElementCount::getScalable(NumEls), NF};
43774386
#define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
43784387
case BuiltinType::Id: \
43794388
return {BoolTy, llvm::ElementCount::getScalable(NumEls), NF};
43804389
#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
43814390
ElBits, NF) \
43824391
case BuiltinType::Id: \
4383-
return {getIntTypeForBitwidth(ElBits, false), \
4384-
llvm::ElementCount::getFixed(NumEls), NF};
4385-
#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
4392+
return {MFloat8Ty, llvm::ElementCount::getFixed(NumEls), NF};
4393+
#define SVE_TYPE(Name, Id, SingletonId)
43864394
#include "clang/Basic/AArch64SVEACLETypes.def"
43874395

43884396
#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, \
@@ -4444,11 +4452,16 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
44444452
EltTySize == ElBits && NumElts == (NumEls * NF) && NumFields == 1) { \
44454453
return SingletonId; \
44464454
}
4455+
#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
4456+
ElBits, NF) \
4457+
if (EltTy->isMFloat8Type() && EltTySize == ElBits && \
4458+
NumElts == (NumEls * NF) && NumFields == 1) { \
4459+
return SingletonId; \
4460+
}
44474461
#define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
44484462
if (EltTy->isBooleanType() && NumElts == (NumEls * NF) && NumFields == 1) \
44494463
return SingletonId;
4450-
#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
4451-
#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
4464+
#define SVE_TYPE(Name, Id, SingletonId)
44524465
#include "clang/Basic/AArch64SVEACLETypes.def"
44534466
} else if (Target->hasRISCVVTypes()) {
44544467
uint64_t EltTySize = getTypeSize(EltTy);
@@ -12153,8 +12166,15 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
1215312166
RequiresICE, false);
1215412167
assert(!RequiresICE && "Can't require vector ICE");
1215512168

12156-
// TODO: No way to make AltiVec vectors in builtins yet.
12157-
Type = Context.getVectorType(ElementType, NumElements, VectorKind::Generic);
12169+
if (ElementType == Context.MFloat8Ty) {
12170+
assert((NumElements == 8 || NumElements == 16) &&
12171+
"Invalid number of elements");
12172+
Type = NumElements == 8 ? Context.MFloat8x8Ty : Context.MFloat8x16Ty;
12173+
} else {
12174+
// TODO: No way to make AltiVec vectors in builtins yet.
12175+
Type =
12176+
Context.getVectorType(ElementType, NumElements, VectorKind::Generic);
12177+
}
1215812178
break;
1215912179
}
1216012180
case 'E': {
@@ -12210,6 +12230,9 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
1221012230
case 'p':
1221112231
Type = Context.getProcessIDType();
1221212232
break;
12233+
case 'm':
12234+
Type = Context.MFloat8Ty;
12235+
break;
1221312236
}
1221412237

1221512238
// If there are modifiers and if we're allowed to parse them, go for it.

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3438,6 +3438,11 @@ void CXXNameMangler::mangleType(const BuiltinType *T) {
34383438
type_name = MangledName; \
34393439
Out << (type_name == Name ? "u" : "") << type_name.size() << type_name; \
34403440
break;
3441+
#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
3442+
case BuiltinType::Id: \
3443+
type_name = MangledName; \
3444+
Out << (type_name == Name ? "u" : "") << type_name.size() << type_name; \
3445+
break;
34413446
#include "clang/Basic/AArch64SVEACLETypes.def"
34423447
#define PPC_VECTOR_TYPE(Name, Id, Size) \
34433448
case BuiltinType::Id: \

clang/lib/AST/Type.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,9 +2527,7 @@ bool Type::isSVESizelessBuiltinType() const {
25272527
#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId) \
25282528
case BuiltinType::Id: \
25292529
return true;
2530-
#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
2531-
case BuiltinType::Id: \
2532-
return false;
2530+
#define SVE_TYPE(Name, Id, SingletonId)
25332531
#include "clang/Basic/AArch64SVEACLETypes.def"
25342532
default:
25352533
return false;

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -507,13 +507,15 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
507507
case BuiltinType::Id:
508508
#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
509509
case BuiltinType::Id:
510-
#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
510+
#define SVE_TYPE(Name, Id, SingletonId)
511511
#include "clang/Basic/AArch64SVEACLETypes.def"
512512
{
513513
ASTContext::BuiltinVectorTypeInfo Info =
514514
Context.getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
515-
auto VTy =
516-
llvm::VectorType::get(ConvertType(Info.ElementType), Info.EC);
515+
auto *EltTy = Info.ElementType->isMFloat8Type()
516+
? llvm::Type::getInt8Ty(getLLVMContext())
517+
: ConvertType(Info.ElementType);
518+
auto *VTy = llvm::VectorType::get(EltTy, Info.EC);
517519
switch (Info.NumVectors) {
518520
default:
519521
llvm_unreachable("Expected 1, 2, 3 or 4 vectors!");
@@ -529,6 +531,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
529531
}
530532
case BuiltinType::SveCount:
531533
return llvm::TargetExtType::get(getLLVMContext(), "aarch64.svcount");
534+
case BuiltinType::MFloat8:
535+
return llvm::VectorType::get(llvm::Type::getInt8Ty(getLLVMContext()), 1,
536+
false);
532537
#define PPC_VECTOR_TYPE(Name, Id, Size) \
533538
case BuiltinType::Id: \
534539
ResultType = \
@@ -650,6 +655,8 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
650655
// An ext_vector_type of Bool is really a vector of bits.
651656
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
652657
? llvm::Type::getInt1Ty(getLLVMContext())
658+
: VT->getElementType()->isMFloat8Type()
659+
? llvm::Type::getInt8Ty(getLLVMContext())
653660
: ConvertType(VT->getElementType());
654661
ResultType = llvm::FixedVectorType::get(IRElemTy, VT->getNumElements());
655662
break;

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {
243243

244244
case BuiltinType::SChar:
245245
case BuiltinType::UChar:
246+
case BuiltinType::MFloat8:
246247
return llvm::ScalableVectorType::get(
247248
llvm::Type::getInt8Ty(getVMContext()), 16);
248249

@@ -761,8 +762,10 @@ bool AArch64ABIInfo::passAsPureScalableType(
761762
getContext().getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
762763
assert(Info.NumVectors > 0 && Info.NumVectors <= 4 &&
763764
"Expected 1, 2, 3 or 4 vectors!");
764-
auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
765-
Info.EC.getKnownMinValue());
765+
llvm::Type *EltTy = Info.ElementType->isMFloat8Type()
766+
? llvm::Type::getInt8Ty(getVMContext())
767+
: CGT.ConvertType(Info.ElementType);
768+
auto *VTy = llvm::ScalableVectorType::get(EltTy, Info.EC.getKnownMinValue());
766769

767770
if (CoerceToSeq.size() + Info.NumVectors > 12)
768771
return false;

clang/utils/TableGen/SveEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,15 +448,15 @@ std::string SVEType::builtinBaseType() const {
448448
case TypeKind::PredicatePattern:
449449
return "i";
450450
case TypeKind::Fpm:
451-
return "Wi";
451+
return "UWi";
452452
case TypeKind::Predicate:
453453
return "b";
454454
case TypeKind::BFloat16:
455455
assert(ElementBitwidth == 16 && "Invalid BFloat16!");
456456
return "y";
457457
case TypeKind::MFloat8:
458458
assert(ElementBitwidth == 8 && "Invalid MFloat8!");
459-
return "c";
459+
return "m";
460460
case TypeKind::Float:
461461
switch (ElementBitwidth) {
462462
case 16:

0 commit comments

Comments
 (0)