Skip to content

Allow normal function results of @yield_once coroutines #69843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6067,6 +6067,14 @@ executing the ``begin_apply``) were being "called" by the ``yield``:
or move the value from that position before ending or aborting the
coroutine.

A coroutine optionally may produce normal results. These do not have
``@yields`` annotation in the result type tuple.
::
(%float, %token) = begin_apply %0() : $@yield_once () -> (@yields Float, Int)

Normal results of a coroutine are produced by the corresponding ``end_apply``
instruction.

A ``begin_apply`` must be uniquely either ended or aborted before
exiting the function or looping to an earlier portion of the function.

Expand Down Expand Up @@ -6096,9 +6104,9 @@ end_apply
`````````
::

sil-instruction ::= 'end_apply' sil-value
sil-instruction ::= 'end_apply' sil-value 'as' sil-type

end_apply %token
end_apply %token as $()

Ends the given coroutine activation, which is currently suspended at
a ``yield`` instruction. Transfers control to the coroutine and takes
Expand All @@ -6108,8 +6116,8 @@ when the coroutine reaches a ``return`` instruction.
The operand must always be the token result of a ``begin_apply``
instruction, which is why it need not specify a type.

``end_apply`` currently has no instruction results. If coroutines were
allowed to have normal results, they would be producted by ``end_apply``.
The result of ``end_apply`` is the normal result of the coroutine function (the
operand of the ``return`` instruction)."

When throwing coroutines are supported, there will need to be a
``try_end_apply`` instruction.
Expand Down
41 changes: 22 additions & 19 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4729,24 +4729,27 @@ class SILFunctionType final
using Representation = SILExtInfoBuilder::Representation;

private:
unsigned NumParameters;
unsigned NumParameters = 0;

// These are *normal* results if this is not a coroutine and *yield* results
// otherwise.
unsigned NumAnyResults; // Not including the ErrorResult.
unsigned NumAnyIndirectFormalResults; // Subset of NumAnyResults.
unsigned NumPackResults; // Subset of NumAnyIndirectFormalResults.
// These are *normal* results
unsigned NumAnyResults = 0; // Not including the ErrorResult.
unsigned NumAnyIndirectFormalResults = 0; // Subset of NumAnyResults.
unsigned NumPackResults = 0; // Subset of NumAnyIndirectFormalResults.
// These are *yield* results
unsigned NumAnyYieldResults = 0; // Not including the ErrorResult.
unsigned NumAnyIndirectFormalYieldResults = 0; // Subset of NumAnyYieldResults.
unsigned NumPackYieldResults = 0; // Subset of NumAnyIndirectFormalYieldResults.

// [NOTE: SILFunctionType-layout]
// The layout of a SILFunctionType in memory is:
// SILFunctionType
// SILParameterInfo[NumParameters]
// SILResultInfo[isCoroutine() ? 0 : NumAnyResults]
// SILResultInfo[NumAnyResults]
// SILResultInfo? // if hasErrorResult()
// SILYieldInfo[isCoroutine() ? NumAnyResults : 0]
// SILYieldInfo[NumAnyYieldResults]
// SubstitutionMap[HasPatternSubs + HasInvocationSubs]
// CanType? // if !isCoro && NumAnyResults > 1, formal result cache
// CanType? // if !isCoro && NumAnyResults > 1, all result cache
// CanType? // if NumAnyResults > 1, formal result cache
// CanType? // if NumAnyResults > 1, all result cache

CanGenericSignature InvocationGenericSig;
ProtocolConformanceRef WitnessMethodConformance;
Expand Down Expand Up @@ -4785,7 +4788,7 @@ class SILFunctionType final

/// Do we have slots for caches of the normal-result tuple type?
bool hasResultCache() const {
return NumAnyResults > 1 && !isCoroutine();
return NumAnyResults > 1;
}

CanType &getMutableFormalResultsCache() const {
Expand Down Expand Up @@ -4873,14 +4876,14 @@ class SILFunctionType final
ArrayRef<SILYieldInfo> getYields() const {
return const_cast<SILFunctionType *>(this)->getMutableYields();
}
unsigned getNumYields() const { return isCoroutine() ? NumAnyResults : 0; }
unsigned getNumYields() const { return NumAnyYieldResults; }

/// Return the array of all result information. This may contain inter-mingled
/// direct and indirect results.
ArrayRef<SILResultInfo> getResults() const {
return const_cast<SILFunctionType *>(this)->getMutableResults();
}
unsigned getNumResults() const { return isCoroutine() ? 0 : NumAnyResults; }
unsigned getNumResults() const { return NumAnyResults; }

ArrayRef<SILResultInfo> getResultsWithError() const {
return const_cast<SILFunctionType *>(this)->getMutableResultsWithError();
Expand Down Expand Up @@ -4917,17 +4920,17 @@ class SILFunctionType final
// indirect property, not the SIL indirect property, should be consulted to
// determine whether function reabstraction is necessary.
unsigned getNumIndirectFormalResults() const {
return isCoroutine() ? 0 : NumAnyIndirectFormalResults;
return NumAnyIndirectFormalResults;
}
/// Does this function have any formally indirect results?
bool hasIndirectFormalResults() const {
return getNumIndirectFormalResults() != 0;
}
unsigned getNumDirectFormalResults() const {
return isCoroutine() ? 0 : NumAnyResults - NumAnyIndirectFormalResults;
return NumAnyResults - NumAnyIndirectFormalResults;
}
unsigned getNumPackResults() const {
return isCoroutine() ? 0 : NumPackResults;
return NumPackResults;
}
bool hasIndirectErrorResult() const {
return hasErrorResult() && getErrorResult().isFormalIndirect();
Expand Down Expand Up @@ -4985,17 +4988,17 @@ class SILFunctionType final
TypeExpansionContext expansion);

unsigned getNumIndirectFormalYields() const {
return isCoroutine() ? NumAnyIndirectFormalResults : 0;
return NumAnyIndirectFormalYieldResults;
}
/// Does this function have any formally indirect yields?
bool hasIndirectFormalYields() const {
return getNumIndirectFormalYields() != 0;
}
unsigned getNumDirectFormalYields() const {
return isCoroutine() ? NumAnyResults - NumAnyIndirectFormalResults : 0;
return NumAnyYieldResults - NumAnyIndirectFormalYieldResults;
}
unsigned getNumPackYields() const {
return isCoroutine() ? NumPackResults : 0;
return NumPackYieldResults;
}

struct IndirectFormalYieldFilter {
Expand Down
6 changes: 3 additions & 3 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,11 +576,11 @@ class SILBuilder {
beginApply));
}

EndApplyInst *createEndApply(SILLocation loc, SILValue beginApply) {
EndApplyInst *createEndApply(SILLocation loc, SILValue beginApply, SILType ResultType) {
return insert(new (getModule()) EndApplyInst(getSILDebugLocation(loc),
beginApply));
beginApply, ResultType));
}

BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,
SubstitutionMap Subs,
ArrayRef<SILValue> Args) {
Expand Down
3 changes: 2 additions & 1 deletion include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,8 @@ SILCloner<ImplClass>::visitEndApplyInst(EndApplyInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
recordClonedInstruction(
Inst, getBuilder().createEndApply(getOpLocation(Inst->getLoc()),
getOpValue(Inst->getOperand())));
getOpValue(Inst->getOperand()),
getOpType(Inst->getType())));
}

template<typename ImplClass>
Expand Down
7 changes: 4 additions & 3 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3200,11 +3200,12 @@ class AbortApplyInst
/// normally.
class EndApplyInst
: public UnaryInstructionBase<SILInstructionKind::EndApplyInst,
NonValueInstruction> {
SingleValueInstruction> {
friend SILBuilder;

EndApplyInst(SILDebugLocation debugLoc, SILValue beginApplyToken)
: UnaryInstructionBase(debugLoc, beginApplyToken) {
EndApplyInst(SILDebugLocation debugLoc, SILValue beginApplyToken,
SILType Ty)
: UnaryInstructionBase(debugLoc, beginApplyToken, Ty) {
assert(isaResultOf<BeginApplyInst>(beginApplyToken) &&
isaResultOf<BeginApplyInst>(beginApplyToken)->isBeginApplyToken());
}
Expand Down
4 changes: 2 additions & 2 deletions include/swift/SIL/SILNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,8 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
SingleValueInstruction, MayHaveSideEffects, MayRelease)
SINGLE_VALUE_INST(PartialApplyInst, partial_apply,
SingleValueInstruction, MayHaveSideEffects, DoesNotRelease)
SINGLE_VALUE_INST(EndApplyInst, end_apply,
SILInstruction, MayHaveSideEffects, MayRelease)

// Metatypes
SINGLE_VALUE_INST(MetatypeInst, metatype,
Expand Down Expand Up @@ -871,8 +873,6 @@ NON_VALUE_INST(UncheckedRefCastAddrInst, unchecked_ref_cast_addr,
SILInstruction, MayHaveSideEffects, DoesNotRelease)
NON_VALUE_INST(AllocGlobalInst, alloc_global,
SILInstruction, MayHaveSideEffects, DoesNotRelease)
NON_VALUE_INST(EndApplyInst, end_apply,
SILInstruction, MayHaveSideEffects, MayRelease)
NON_VALUE_INST(AbortApplyInst, abort_apply,
SILInstruction, MayHaveSideEffects, MayRelease)
NON_VALUE_INST(PackElementSetInst, pack_element_set,
Expand Down
39 changes: 19 additions & 20 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4611,29 +4611,29 @@ SILFunctionType::SILFunctionType(
!ext.getLifetimeDependenceInfo().empty();
Bits.SILFunctionType.CoroutineKind = unsigned(coroutineKind);
NumParameters = params.size();
if (coroutineKind == SILCoroutineKind::None) {
assert(yields.empty());
NumAnyResults = normalResults.size();
NumAnyIndirectFormalResults = 0;
NumPackResults = 0;
for (auto &resultInfo : normalResults) {
if (resultInfo.isFormalIndirect())
NumAnyIndirectFormalResults++;
if (resultInfo.isPack())
NumPackResults++;
}
memcpy(getMutableResults().data(), normalResults.data(),
normalResults.size() * sizeof(SILResultInfo));
} else {
assert(normalResults.empty());
NumAnyResults = yields.size();
NumAnyIndirectFormalResults = 0;
assert((coroutineKind == SILCoroutineKind::None && yields.empty()) ||
coroutineKind != SILCoroutineKind::None);

NumAnyResults = normalResults.size();
NumAnyIndirectFormalResults = 0;
NumPackResults = 0;
for (auto &resultInfo : normalResults) {
if (resultInfo.isFormalIndirect())
NumAnyIndirectFormalResults++;
if (resultInfo.isPack())
NumPackResults++;
}
memcpy(getMutableResults().data(), normalResults.data(),
normalResults.size() * sizeof(SILResultInfo));
if (coroutineKind != SILCoroutineKind::None) {
NumAnyYieldResults = yields.size();
NumAnyIndirectFormalYieldResults = 0;
NumPackResults = 0;
for (auto &yieldInfo : yields) {
if (yieldInfo.isFormalIndirect())
NumAnyIndirectFormalResults++;
NumAnyIndirectFormalYieldResults++;
if (yieldInfo.isPack())
NumPackResults++;
NumPackYieldResults++;
}
memcpy(getMutableYields().data(), yields.data(),
yields.size() * sizeof(SILYieldInfo));
Expand Down Expand Up @@ -4805,7 +4805,6 @@ CanSILFunctionType SILFunctionType::get(
llvm::Optional<SILResultInfo> errorResult, SubstitutionMap patternSubs,
SubstitutionMap invocationSubs, const ASTContext &ctx,
ProtocolConformanceRef witnessMethodConformance) {
assert(coroutineKind == SILCoroutineKind::None || normalResults.empty());
assert(coroutineKind != SILCoroutineKind::None || yields.empty());
assert(!ext.isPseudogeneric() || genericSig ||
coroutineKind != SILCoroutineKind::None);
Expand Down
69 changes: 63 additions & 6 deletions lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,24 +656,34 @@ namespace {
}

void SignatureExpansion::expandCoroutineResult(bool forContinuation) {
assert(FnType->getNumResults() == 0 &&
"having both normal and yield results is currently unsupported");

// The return type may be different for the ramp function vs. the
// continuations.
if (forContinuation) {
switch (FnType->getCoroutineKind()) {
case SILCoroutineKind::None:
llvm_unreachable("should have been filtered out before here");

// Yield-once coroutines just return void from the continuation.
case SILCoroutineKind::YieldOnce:
ResultIRType = IGM.VoidTy;
// Yield-once coroutines may optionaly return a value from the continuation.
case SILCoroutineKind::YieldOnce: {
auto fnConv = getSILFuncConventions();

assert(fnConv.getNumIndirectSILResults() == 0);
// Ensure that no parameters were added before to correctly record their ABI
// details.
assert(ParamIRTypes.empty());

// Expand the direct result.
const TypeInfo *directResultTypeInfo;
std::tie(ResultIRType, directResultTypeInfo) = expandDirectResult();

return;
}

// Yield-many coroutines yield the same types from the continuation
// as they do from the ramp function.
case SILCoroutineKind::YieldMany:
assert(FnType->getNumResults() == 0 &&
"having both normal and yield results is currently unsupported");
break;
}
}
Expand Down Expand Up @@ -5803,6 +5813,53 @@ void irgen::emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &asyncLayout,
emitAsyncReturn(IGF, asyncLayout, fnType, nativeResults);
}

void irgen::emitYieldOnceCoroutineResult(IRGenFunction &IGF, Explosion &result,
SILType funcResultType, SILType returnResultType) {
auto &Builder = IGF.Builder;
auto &IGM = IGF.IGM;

// Create coroutine exit block and branch to it.
auto coroEndBB = IGF.createBasicBlock("coro.end.normal");
IGF.setCoroutineExitBlock(coroEndBB);
Builder.CreateBr(coroEndBB);

// Emit the block.
Builder.emitBlock(coroEndBB);
auto handle = IGF.getCoroutineHandle();

llvm::Value *resultToken = nullptr;
if (result.empty()) {
assert(IGM.getTypeInfo(returnResultType)
.nativeReturnValueSchema(IGM)
.empty() &&
"Empty explosion must match the native calling convention");
// No results: just use none token
resultToken = llvm::ConstantTokenNone::get(Builder.getContext());
} else {
// Capture results via `coro_end_results` intrinsic
result = IGF.coerceValueTo(returnResultType, result, funcResultType);
auto &nativeSchema =
IGM.getTypeInfo(funcResultType).nativeReturnValueSchema(IGM);
assert(!nativeSchema.requiresIndirect());

Explosion native = nativeSchema.mapIntoNative(IGM, IGF, result,
funcResultType,
false /* isOutlined */);
SmallVector<llvm::Value *, 1> args;
for (unsigned i = 0, e = native.size(); i != e; ++i)
args.push_back(native.claimNext());

resultToken =
Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_end_results, args);
}

Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_end,
{handle,
/*is unwind*/ Builder.getFalse(),
resultToken});
Builder.CreateUnreachable();
}

FunctionPointer
IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) {
auto *fnTy = llvm::FunctionType::get(
Expand Down
2 changes: 2 additions & 0 deletions lib/IRGen/GenCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ namespace irgen {
SILType funcResultTypeInContext,
CanSILFunctionType fnType, Explosion &result,
Explosion &error);
void emitYieldOnceCoroutineResult(IRGenFunction &IGF, Explosion &result,
SILType funcResultType, SILType returnResultType);

Address emitAutoDiffCreateLinearMapContextWithType(
IRGenFunction &IGF, llvm::Value *topLevelSubcontextMetatype);
Expand Down
2 changes: 1 addition & 1 deletion lib/IRGen/IRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ void IRGenFunction::emitAwaitAsyncContinuation(
// because the continuation result is not available yet. When the
// continuation is later resumed, the task will get scheduled
// starting from the suspension point.
emitCoroutineOrAsyncExit();
emitCoroutineOrAsyncExit(false);
}

Builder.emitBlock(contBB);
Expand Down
12 changes: 11 additions & 1 deletion lib/IRGen/IRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ class IRGenFunction {
CoroutineHandle = handle;
}

llvm::BasicBlock *getCoroutineExitBlock() const {
return CoroutineExitBlock;
}

void setCoroutineExitBlock(llvm::BasicBlock *block) {
assert(CoroutineExitBlock == nullptr && "already set exit BB");
assert(block != nullptr && "setting a null exit BB");
CoroutineExitBlock = block;
}

llvm::Value *getAsyncTask();
llvm::Value *getAsyncContext();
void storeCurrentAsyncContext(llvm::Value *context);
Expand Down Expand Up @@ -236,7 +246,7 @@ class IRGenFunction {
bool callsAnyAlwaysInlineThunksWithForeignExceptionTraps = false;

public:
void emitCoroutineOrAsyncExit();
void emitCoroutineOrAsyncExit(bool isUnwind);

//--- Helper methods -----------------------------------------------------------
public:
Expand Down
Loading