Skip to content

Commit 0f3c430

Browse files
committed
IRGen: Refactor function type metadata emission
1 parent 92ad172 commit 0f3c430

File tree

1 file changed

+189
-125
lines changed

1 file changed

+189
-125
lines changed

lib/IRGen/MetadataRequest.cpp

Lines changed: 189 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,28 +1370,19 @@ static llvm::Value *getFunctionParameterRef(IRGenFunction &IGF,
13701370
return IGF.emitAbstractTypeMetadataRef(type);
13711371
}
13721372

1373-
static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
1374-
CanFunctionType type,
1375-
DynamicMetadataRequest request) {
1376-
auto result =
1377-
IGF.emitAbstractTypeMetadataRef(type->getResult()->getCanonicalType());
1378-
1379-
auto params = type.getParams();
1380-
auto numParams = params.size();
1381-
1382-
// Retrieve the ABI parameter flags from the type-level parameter
1383-
// flags.
1384-
auto getABIParameterFlags = [](ParameterTypeFlags flags) {
1385-
return ParameterFlags()
1373+
/// Mapping type-level parameter flags to ABI parameter flags.
1374+
static ParameterFlags getABIParameterFlags(ParameterTypeFlags flags) {
1375+
return ParameterFlags()
13861376
.withValueOwnership(flags.getValueOwnership())
13871377
.withVariadic(flags.isVariadic())
13881378
.withAutoClosure(flags.isAutoClosure())
13891379
.withNoDerivative(flags.isNoDerivative())
13901380
.withIsolated(flags.isIsolated());
1391-
};
1381+
}
13921382

1383+
static FunctionTypeFlags getFunctionTypeFlags(CanFunctionType type) {
13931384
bool hasParameterFlags = false;
1394-
for (auto param : params) {
1385+
for (auto param : type.getParams()) {
13951386
if (!getABIParameterFlags(param.getParameterFlags()).isNone()) {
13961387
hasParameterFlags = true;
13971388
break;
@@ -1417,80 +1408,149 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
14171408
break;
14181409
}
14191410

1420-
FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind;
1421-
switch (type->getDifferentiabilityKind()) {
1422-
case DifferentiabilityKind::NonDifferentiable:
1423-
metadataDifferentiabilityKind =
1424-
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
1425-
break;
1426-
case DifferentiabilityKind::Normal:
1427-
metadataDifferentiabilityKind =
1428-
FunctionMetadataDifferentiabilityKind::Normal;
1429-
break;
1430-
case DifferentiabilityKind::Linear:
1431-
metadataDifferentiabilityKind =
1432-
FunctionMetadataDifferentiabilityKind::Linear;
1433-
break;
1434-
case DifferentiabilityKind::Forward:
1435-
metadataDifferentiabilityKind =
1436-
FunctionMetadataDifferentiabilityKind::Forward;
1437-
break;
1438-
case DifferentiabilityKind::Reverse:
1439-
metadataDifferentiabilityKind =
1440-
FunctionMetadataDifferentiabilityKind::Reverse;
1441-
break;
1411+
return FunctionTypeFlags()
1412+
.withConvention(metadataConvention)
1413+
.withAsync(type->isAsync())
1414+
.withConcurrent(type->isSendable())
1415+
.withThrows(type->isThrowing())
1416+
.withParameterFlags(hasParameterFlags)
1417+
.withEscaping(isEscaping)
1418+
.withDifferentiable(type->isDifferentiable())
1419+
.withGlobalActor(!type->getGlobalActor().isNull());
1420+
}
1421+
1422+
namespace {
1423+
struct FunctionTypeMetadataParamInfo {
1424+
StackAddress parameters;
1425+
StackAddress paramFlags;
1426+
unsigned numParams;
1427+
};
1428+
}
1429+
1430+
static FunctionTypeMetadataParamInfo
1431+
emitFunctionTypeMetadataParams(IRGenFunction &IGF,
1432+
AnyFunctionType::CanParamArrayRef params,
1433+
FunctionTypeFlags flags,
1434+
DynamicMetadataRequest request,
1435+
SmallVectorImpl<llvm::Value *> &arguments) {
1436+
FunctionTypeMetadataParamInfo info;
1437+
info.numParams = params.size();
1438+
1439+
ConstantInitBuilder paramFlags(IGF.IGM);
1440+
auto flagsArr = paramFlags.beginArray();
1441+
1442+
if (!params.empty()) {
1443+
auto arrayTy =
1444+
llvm::ArrayType::get(IGF.IGM.TypeMetadataPtrTy, info.numParams);
1445+
info.parameters = StackAddress(IGF.createAlloca(
1446+
arrayTy, IGF.IGM.getTypeMetadataAlignment(), "function-parameters"));
1447+
1448+
IGF.Builder.CreateLifetimeStart(info.parameters.getAddress(),
1449+
IGF.IGM.getPointerSize() * info.numParams);
1450+
1451+
for (unsigned i : indices(params)) {
1452+
auto param = params[i];
1453+
auto paramFlags = getABIParameterFlags(param.getParameterFlags());
1454+
1455+
auto argPtr = IGF.Builder.CreateStructGEP(info.parameters.getAddress(), i,
1456+
IGF.IGM.getPointerSize());
1457+
auto *typeRef = getFunctionParameterRef(IGF, param);
1458+
IGF.Builder.CreateStore(typeRef, argPtr);
1459+
if (i == 0)
1460+
arguments.push_back(argPtr.getAddress());
1461+
1462+
flagsArr.addInt32(paramFlags.getIntValue());
1463+
}
1464+
} else {
1465+
auto parametersPtr =
1466+
llvm::ConstantPointerNull::get(
1467+
IGF.IGM.TypeMetadataPtrTy->getPointerTo());
1468+
arguments.push_back(parametersPtr);
1469+
}
1470+
1471+
auto *Int32Ptr = IGF.IGM.Int32Ty->getPointerTo();
1472+
if (flags.hasParameterFlags()) {
1473+
auto *flagsVar = flagsArr.finishAndCreateGlobal(
1474+
"parameter-flags", IGF.IGM.getPointerAlignment(),
1475+
/* constant */ true);
1476+
arguments.push_back(IGF.Builder.CreateBitCast(flagsVar, Int32Ptr));
1477+
} else {
1478+
flagsArr.abandon();
1479+
arguments.push_back(llvm::ConstantPointerNull::get(Int32Ptr));
14421480
}
14431481

1444-
auto flags = FunctionTypeFlags()
1445-
.withNumParameters(numParams)
1446-
.withConvention(metadataConvention)
1447-
.withAsync(type->isAsync())
1448-
.withConcurrent(type->isSendable())
1449-
.withThrows(type->isThrowing())
1450-
.withParameterFlags(hasParameterFlags)
1451-
.withEscaping(isEscaping)
1452-
.withDifferentiable(type->isDifferentiable())
1453-
.withGlobalActor(!type->getGlobalActor().isNull());
1454-
1455-
auto flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy,
1456-
flags.getIntValue());
1457-
llvm::Value *diffKindVal = nullptr;
1458-
if (type->isDifferentiable()) {
1459-
assert(metadataDifferentiabilityKind.isDifferentiable());
1460-
diffKindVal = llvm::ConstantInt::get(
1461-
IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue());
1462-
} else if (type->getGlobalActor()) {
1463-
diffKindVal = llvm::ConstantInt::get(
1464-
IGF.IGM.SizeTy,
1465-
FunctionMetadataDifferentiabilityKind::NonDifferentiable);
1466-
}
1467-
1468-
auto collectParameters =
1469-
[&](llvm::function_ref<void(unsigned, llvm::Value *,
1470-
ParameterFlags flags)>
1471-
processor) {
1472-
for (auto index : indices(params)) {
1473-
auto param = params[index];
1474-
auto flags = param.getParameterFlags();
1475-
1476-
auto parameterFlags = getABIParameterFlags(flags);
1477-
processor(index, getFunctionParameterRef(IGF, param),
1478-
parameterFlags);
1479-
}
1480-
};
1482+
return info;
1483+
}
1484+
1485+
static FunctionTypeMetadataParamInfo
1486+
emitDynamicFunctionTypeMetadataParams(IRGenFunction &IGF,
1487+
AnyFunctionType::CanParamArrayRef params,
1488+
FunctionTypeFlags flags,
1489+
CanPackType packType,
1490+
DynamicMetadataRequest request,
1491+
SmallVectorImpl<llvm::Value *> &arguments) {
1492+
assert(false);
1493+
}
1494+
1495+
static void cleanupFunctionTypeMetadataParams(IRGenFunction &IGF,
1496+
FunctionTypeMetadataParamInfo info) {
1497+
if (info.parameters.isValid()) {
1498+
if (info.parameters.getExtraInfo()) {
1499+
IGF.emitDeallocateDynamicAlloca(info.parameters);
1500+
} else {
1501+
IGF.Builder.CreateLifetimeEnd(info.parameters.getAddress(),
1502+
IGF.IGM.getPointerSize() * info.numParams);
1503+
}
1504+
}
1505+
}
1506+
1507+
static CanPackType getInducedPackType(AnyFunctionType::CanParamArrayRef params,
1508+
ASTContext &ctx) {
1509+
SmallVector<CanType, 2> elts;
1510+
for (auto param : params)
1511+
elts.push_back(param.getPlainType());
1512+
1513+
return CanPackType::get(ctx, elts);
1514+
}
1515+
1516+
static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
1517+
CanFunctionType type,
1518+
DynamicMetadataRequest request) {
1519+
auto result =
1520+
IGF.emitAbstractTypeMetadataRef(type->getResult()->getCanonicalType());
1521+
1522+
auto params = type.getParams();
1523+
bool hasPackExpansion = type->containsPackExpansionParam();
1524+
1525+
auto flags = getFunctionTypeFlags(type);
1526+
llvm::Value *flagsVal = nullptr;
1527+
llvm::Value *shapeExpression = nullptr;
1528+
CanPackType packType;
1529+
1530+
if (!hasPackExpansion) {
1531+
flags = flags.withNumParameters(params.size());
1532+
flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy,
1533+
flags.getIntValue());
1534+
} else {
1535+
packType = getInducedPackType(type.getParams(), type->getASTContext());
1536+
auto *shapeExpression = IGF.emitPackShapeExpression(packType);
1537+
1538+
flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy,
1539+
flags.getIntValue());
1540+
flagsVal = IGF.Builder.CreateOr(flagsVal, shapeExpression);
1541+
}
14811542

14821543
auto constructSimpleCall =
14831544
[&](llvm::SmallVectorImpl<llvm::Value *> &arguments)
14841545
-> FunctionPointer {
1546+
assert(!flags.hasParameterFlags());
1547+
assert(!shapeExpression);
1548+
14851549
arguments.push_back(flagsVal);
14861550

1487-
collectParameters([&](unsigned i, llvm::Value *typeRef,
1488-
ParameterFlags flags) {
1489-
arguments.push_back(typeRef);
1490-
if (hasParameterFlags)
1491-
arguments.push_back(
1492-
llvm::ConstantInt::get(IGF.IGM.Int32Ty, flags.getIntValue()));
1493-
});
1551+
for (auto param : params) {
1552+
arguments.push_back(getFunctionParameterRef(IGF, param));
1553+
}
14941554

14951555
arguments.push_back(result);
14961556

@@ -1512,13 +1572,13 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
15121572
}
15131573
};
15141574

1515-
switch (numParams) {
1575+
switch (params.size()) {
15161576
case 0:
15171577
case 1:
15181578
case 2:
15191579
case 3: {
1520-
if (!hasParameterFlags && !type->isDifferentiable() &&
1521-
!type->getGlobalActor()) {
1580+
if (!flags.hasParameterFlags() && !type->isDifferentiable() &&
1581+
!type->getGlobalActor() && !hasPackExpansion) {
15221582
llvm::SmallVector<llvm::Value *, 8> arguments;
15231583
auto metadataFn = constructSimpleCall(arguments);
15241584
auto *call = IGF.Builder.CreateCall(metadataFn, arguments);
@@ -1537,54 +1597,60 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
15371597
"0 parameter case should be specialized unless it is a "
15381598
"differentiable function or has a global actor");
15391599

1540-
auto *const Int32Ptr = IGF.IGM.Int32Ty->getPointerTo();
15411600
llvm::SmallVector<llvm::Value *, 8> arguments;
15421601

15431602
arguments.push_back(flagsVal);
15441603

1545-
if (diffKindVal) {
1546-
arguments.push_back(diffKindVal);
1547-
}
1548-
1549-
ConstantInitBuilder paramFlags(IGF.IGM);
1550-
auto flagsArr = paramFlags.beginArray();
1604+
llvm::Value *diffKindVal = nullptr;
15511605

1552-
Address parameters;
1553-
if (!params.empty()) {
1554-
auto arrayTy =
1555-
llvm::ArrayType::get(IGF.IGM.TypeMetadataPtrTy, numParams);
1556-
parameters = IGF.createAlloca(
1557-
arrayTy, IGF.IGM.getTypeMetadataAlignment(), "function-parameters");
1558-
1559-
IGF.Builder.CreateLifetimeStart(parameters,
1560-
IGF.IGM.getPointerSize() * numParams);
1606+
{
1607+
FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind;
1608+
switch (type->getDifferentiabilityKind()) {
1609+
case DifferentiabilityKind::NonDifferentiable:
1610+
metadataDifferentiabilityKind =
1611+
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
1612+
break;
1613+
case DifferentiabilityKind::Normal:
1614+
metadataDifferentiabilityKind =
1615+
FunctionMetadataDifferentiabilityKind::Normal;
1616+
break;
1617+
case DifferentiabilityKind::Linear:
1618+
metadataDifferentiabilityKind =
1619+
FunctionMetadataDifferentiabilityKind::Linear;
1620+
break;
1621+
case DifferentiabilityKind::Forward:
1622+
metadataDifferentiabilityKind =
1623+
FunctionMetadataDifferentiabilityKind::Forward;
1624+
break;
1625+
case DifferentiabilityKind::Reverse:
1626+
metadataDifferentiabilityKind =
1627+
FunctionMetadataDifferentiabilityKind::Reverse;
1628+
break;
1629+
}
15611630

1562-
collectParameters([&](unsigned i, llvm::Value *typeRef,
1563-
ParameterFlags flags) {
1564-
auto argPtr = IGF.Builder.CreateStructGEP(parameters, i,
1565-
IGF.IGM.getPointerSize());
1566-
IGF.Builder.CreateStore(typeRef, argPtr);
1567-
if (i == 0)
1568-
arguments.push_back(argPtr.getAddress());
1631+
if (type->isDifferentiable()) {
1632+
assert(metadataDifferentiabilityKind.isDifferentiable());
1633+
diffKindVal = llvm::ConstantInt::get(
1634+
IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue());
1635+
} else if (type->getGlobalActor()) {
1636+
diffKindVal = llvm::ConstantInt::get(
1637+
IGF.IGM.SizeTy,
1638+
FunctionMetadataDifferentiabilityKind::NonDifferentiable);
1639+
}
1640+
}
15691641

1570-
if (hasParameterFlags)
1571-
flagsArr.addInt32(flags.getIntValue());
1572-
});
1573-
} else {
1574-
auto parametersPtr =
1575-
llvm::ConstantPointerNull::get(
1576-
IGF.IGM.TypeMetadataPtrTy->getPointerTo());
1577-
arguments.push_back(parametersPtr);
1642+
if (diffKindVal) {
1643+
arguments.push_back(diffKindVal);
15781644
}
15791645

1580-
if (hasParameterFlags) {
1581-
auto *flagsVar = flagsArr.finishAndCreateGlobal(
1582-
"parameter-flags", IGF.IGM.getPointerAlignment(),
1583-
/* constant */ true);
1584-
arguments.push_back(IGF.Builder.CreateBitCast(flagsVar, Int32Ptr));
1646+
FunctionTypeMetadataParamInfo info;
1647+
if (!hasPackExpansion) {
1648+
assert(!shapeExpression);
1649+
info = emitFunctionTypeMetadataParams(IGF, params, flags, request,
1650+
arguments);
15851651
} else {
1586-
flagsArr.abandon();
1587-
arguments.push_back(llvm::ConstantPointerNull::get(Int32Ptr));
1652+
info = emitDynamicFunctionTypeMetadataParams(IGF, params, flags, packType,
1653+
request, arguments);
15881654
}
15891655

15901656
arguments.push_back(result);
@@ -1608,9 +1674,7 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
16081674
auto call = IGF.Builder.CreateCall(getMetadataFn, arguments);
16091675
call->setDoesNotThrow();
16101676

1611-
if (parameters.isValid())
1612-
IGF.Builder.CreateLifetimeEnd(parameters,
1613-
IGF.IGM.getPointerSize() * numParams);
1677+
cleanupFunctionTypeMetadataParams(IGF, info);
16141678

16151679
return MetadataResponse::forComplete(call);
16161680
}

0 commit comments

Comments
 (0)