From 6eefd963c6e190abdf1df2f2d26d068ed0a761ab Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Thu, 26 Sep 2024 21:52:17 -0700 Subject: [PATCH] Fix partial apply forwarder emission for coroutines that are methods of structs with type parameters. Simplify the code while here --- lib/IRGen/GenFunc.cpp | 77 ++++++------------- .../validation-test/modify_accessor.swift | 34 ++++++++ 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index 1e484b7362125..e0beff8a89f7b 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -1120,11 +1120,11 @@ class PartialApplicationForwarderEmission { virtual void addDynamicFunctionContext(Explosion &explosion) = 0; virtual void addDynamicFunctionPointer(Explosion &explosion) = 0; - virtual void addSelf(Explosion &explosion) { addArgument(explosion); } - virtual void addWitnessSelfMetadata(llvm::Value *value) { + void addSelf(Explosion &explosion) { addArgument(explosion); } + void addWitnessSelfMetadata(llvm::Value *value) { addArgument(value); } - virtual void addWitnessSelfWitnessTable(llvm::Value *value) { + void addWitnessSelfWitnessTable(llvm::Value *value) { addArgument(value); } virtual void forwardErrorResult() = 0; @@ -1412,12 +1412,6 @@ class CoroPartialApplicationForwarderEmission : public PartialApplicationForwarderEmission { using super = PartialApplicationForwarderEmission; -private: - llvm::Value *Self; - llvm::Value *FirstData; - llvm::Value *SecondData; - WitnessMetadata Witness; - public: CoroPartialApplicationForwarderEmission( IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd, @@ -1428,8 +1422,7 @@ class CoroPartialApplicationForwarderEmission ArrayRef conventions) : PartialApplicationForwarderEmission( IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType, - substType, outType, subs, layout, conventions), - Self(nullptr), FirstData(nullptr), SecondData(nullptr) {} + substType, outType, subs, layout, conventions) {} void begin() override { auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule()); @@ -1473,41 +1466,13 @@ class CoroPartialApplicationForwarderEmission void gatherArgumentsFromApply() override { super::gatherArgumentsFromApply(false); } - llvm::Value *getDynamicFunctionPointer() override { - llvm::Value *Ret = SecondData; - SecondData = nullptr; - return Ret; - } - llvm::Value *getDynamicFunctionContext() override { - llvm::Value *Ret = FirstData; - FirstData = nullptr; - return Ret; - } + llvm::Value *getDynamicFunctionPointer() override { return args.takeLast(); } + llvm::Value *getDynamicFunctionContext() override { return args.takeLast(); } void addDynamicFunctionContext(Explosion &explosion) override { - assert(!Self && "context value overrides 'self'"); - FirstData = explosion.claimNext(); + addArgument(explosion); } void addDynamicFunctionPointer(Explosion &explosion) override { - SecondData = explosion.claimNext(); - } - void addSelf(Explosion &explosion) override { - assert(!FirstData && "'self' overrides another context value"); - if (!hasSelfContextParameter(origType)) { - // witness methods can be declared on types that are not classes. Pass - // such "self" argument as a plain argument. - addArgument(explosion); - return; - } - Self = explosion.claimNext(); - FirstData = Self; - } - - void addWitnessSelfMetadata(llvm::Value *value) override { - Witness.SelfMetadata = value; - } - - void addWitnessSelfWitnessTable(llvm::Value *value) override { - Witness.SelfWitnessTable = value; + addArgument(explosion); } void forwardErrorResult() override { @@ -1528,13 +1493,26 @@ class CoroPartialApplicationForwarderEmission } Explosion callCoroutine(FunctionPointer &fnPtr) { - Callee callee({origType, substType, subs}, fnPtr, FirstData, SecondData); + bool isWitnessMethodCallee = origType->getRepresentation() == + SILFunctionTypeRepresentation::WitnessMethod; + + WitnessMetadata witnessMetadata; + if (isWitnessMethodCallee) { + witnessMetadata.SelfWitnessTable = args.takeLast(); + witnessMetadata.SelfMetadata = args.takeLast(); + } + + llvm::Value *selfValue = nullptr; + if (calleeHasContext || hasSelfContextParameter(origType)) + selfValue = args.takeLast(); + + Callee callee({origType, substType, subs}, fnPtr, selfValue); std::unique_ptr emitSuspend = - getCallEmission(subIGF, Self, std::move(callee)); + getCallEmission(subIGF, callee.getSwiftContext(), std::move(callee)); emitSuspend->begin(); - emitSuspend->setArgs(args, /*isOutlined=*/false, &Witness); + emitSuspend->setArgs(args, /*isOutlined=*/false, &witnessMetadata); Explosion yieldedValues; emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false); emitSuspend->end(); @@ -1940,12 +1918,7 @@ static llvm::Value *emitPartialApplicationForwarder( } else { argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy); } - if (haveContextArgument) { - Explosion e; - e.add(argValue); - emission->addDynamicFunctionContext(e); - } else - emission->addArgument(argValue); + emission->addArgument(argValue); // If there's a data pointer required, grab it and load out the // extra, previously-curried parameters. diff --git a/test/AutoDiff/validation-test/modify_accessor.swift b/test/AutoDiff/validation-test/modify_accessor.swift index dc7576f2770f3..921cfb1283dc8 100644 --- a/test/AutoDiff/validation-test/modify_accessor.swift +++ b/test/AutoDiff/validation-test/modify_accessor.swift @@ -39,5 +39,39 @@ ModifyAccessorTests.test("SimpleModifyAccessor") { expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct)) } +ModifyAccessorTests.test("GenericModifyAccessor") { + struct S: Differentiable { + private var _x : T + + func _endMutation() {} + + var x: T { + get{_x} + set(newValue) { _x = newValue } + _modify { + defer { _endMutation() } + if (x > -x) { + yield &_x + } else { + yield &_x + } + } + } + + init(_ x : T) { + self._x = x + } + } + + func modify_struct(_ x : Float) -> Float { + var s = S(x) + s.x *= s.x + return s.x + } + + expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct)) +} + + runAllTests()