Skip to content

[AutoDiff] Support curry thunks differentiation in fragile funcs #77615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 17, 2025

Conversation

kovdan01
Copy link
Contributor

Inside fragile functions, we expect function derivatives to be public, which could be achieved by either explicitly marking the functions as differentiable or having a public explicit derivative defined for them. This is obviously not possible for single and double curry thunks which are a special case of AutoClosureExpr.

Instead of looking at the thunk itself, we unwrap it and look at the function being wrapped. While the thunk itself and its differentiability witness will not have public visibility, it's not an issue for the case where the function being wrapped (and its witness) have public visibility.

Fixes #54819
Fixes #75776

Inside fragile functions, we expect function derivatives to be public, which
could be achieved by either explicitly marking the functions as differentiable
or having a public explicit derivative defined for them. This is obviously not
possible for single and double curry thunks which are a special case of
`AutoClosureExpr`.

Instead of looking at the thunk itself, we unwrap it and look at the
function being wrapped. While the thunk itself and its differentiability
witness will not have public visibility, it's not an issue for the case
where the function being wrapped (and its witness) have public visibility.

Fixes swiftlang#54819
Fixes swiftlang#75776
@kovdan01
Copy link
Contributor Author

Tagging @asl

@kovdan01 kovdan01 marked this pull request as ready for review November 14, 2024 16:07
@asl
Copy link
Contributor

asl commented Nov 14, 2024

@swift-ci please test

lib/AST/Expr.cpp Outdated
@@ -2211,6 +2211,20 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
return nullptr;
}

Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need getUnwrappedCurryThunkExpr()? It's not clear to me under what circumstances ae->getFn() will return something other than ae->getCalledValue(/*skipFunctionConversions=*/true);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@slavapestov Do you mean that getUnwrappedCurryThunkCalledValue() might be used instead of getUnwrappedCurryThunkExpr()? I don't think so: I intentionally moved the actual implementation into a separate getUnwrappedCurryThunkImpl() function, so we can either get an Expr or a ValueDecl depending of what we need.

getUnwrappedCurryThunkExpr() seems to have a number of uses over the code base, and that uses seem to expect Expr and not ValueDecl, so it's probably still needed. Anyway, even if some refactoring might be done (while I do not see a room for it in this particular case), probably it would be better to keep it separate from functional changes.

If I got your question wrong and this does not answer it, it would be great if you can clarify your point a bit.

@asl
Copy link
Contributor

asl commented Nov 14, 2024

Tagging @rxwei

@asl asl requested a review from rxwei November 14, 2024 16:16
@asl
Copy link
Contributor

asl commented Nov 14, 2024

Tagging @JaapWijnen

@kovdan01 kovdan01 requested a review from slavapestov November 14, 2024 16:40
@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

It's already included in test/AutoDiff/SILOptimizer/fragile_curry_thunk.swift
@asl
Copy link
Contributor

asl commented Nov 21, 2024

@swift-ci please test

@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@JaapWijnen
Copy link
Contributor

Great to have this fixed thanks @kovdan01. Looking forward to have this reviewed and merged! :)

@kovdan01
Copy link
Contributor Author

kovdan01 commented Dec 4, 2024

Would be glad to see feedback from everyone interested

3 similar comments
@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@kovdan01
Copy link
Contributor Author

kovdan01 commented Jan 6, 2025

Would be glad to see feedback from everyone interested

4 similar comments
@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@kovdan01
Copy link
Contributor Author

Would be glad to see feedback from everyone interested

@kovdan01
Copy link
Contributor Author

kovdan01 commented Feb 4, 2025

Would be glad to see feedback from everyone interested

@asl asl added the AutoDiff label Feb 13, 2025
auto *abstractCE = originalFn->getDeclRef().getAbstractClosureExpr();
if (abstractCE == nullptr)
return nullptr;
auto *autoCE = dyn_cast<AutoClosureExpr>(abstractCE);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could simplify a bit with dyn_cast_or_null

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied the suggestion, thanks, see c674f80

// for that due to the crash
// https://github.com/swiftlang/swift/issues/77613
if (currentAFD->hasCurriedSelf())
afdToSILFn.insert({currentAFD, &currentFunc});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try_emplace + assert on result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied the suggestion, thanks, see c674f80

@kovdan01
Copy link
Contributor Author

I've limited scope of the changes only to AutoDiff code - see 2cadfa0

@kovdan01 kovdan01 requested a review from asl February 14, 2025 12:23
@asl
Copy link
Contributor

asl commented Feb 14, 2025

@swift-ci please test

/*skipFunctionConversions=*/true));
break;
}
assert(afd);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer some message attached to assert. I would probably also change switch to ordinary if, this was you can error out in else case and this assert will be unnecessary at all, as afd will always be a result of cast.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed the logic a bit in 4d74d34. It actually looks like that other thunk types (not single and double curry thunks) could go here, and this is valid, we just need to return nullptr (just as we do right above if we detect not AutoClosureExpr).

Since getUnwrappedCurryThunkExpr already has checks against single and double curry thunks and returns nullptr for other thunk types, using it in conjunction with dyn_cast_or_null with subsequent check against nullptr should be enough.

if (currentAFD->hasCurriedSelf()) {
auto [_, wasEmplace] =
afdToSILFn.try_emplace(currentAFD, &currentFunc);
assert(wasEmplace);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 4d74d34, thanks


auto silFnIt = afdToSILFn.find(afd);
if (silFnIt == afdToSILFn.end()) {
assert(afdToSILFn.empty());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 4d74d34, thanks

@kovdan01 kovdan01 requested a review from asl February 15, 2025 12:05
@asl
Copy link
Contributor

asl commented Feb 17, 2025

@swift-ci please test

Copy link
Contributor

@asl asl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the changes were localized down to autodiff code (no code outside autodiff are to be changed) and looks good to me

@asl asl merged commit 1a42a0c into swiftlang:main Feb 17, 2025
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
4 participants