Skip to content

[AutoDiff] Correctly propagate optional adjoint through switch_enum #74985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 12, 2024
Merged

Conversation

asl
Copy link
Contributor

@asl asl commented Jul 5, 2024

Fixes #74978

@asl asl requested a review from rxwei July 5, 2024 08:18
@asl asl requested a review from eeckstein as a code owner July 5, 2024 08:18
@asl
Copy link
Contributor Author

asl commented Jul 5, 2024

@swift-ci please test

@asl
Copy link
Contributor Author

asl commented Jul 5, 2024

So, the code in question looks like this:

// square2(y:)
sil hidden [ossa] @$s3opt7square21yS2dSg_tF : $@convention(thin) (Optional<Double>) -> Double {
// %0 "y"                                         // users: %14, %3, %1
bb0(%0 : $Optional<Double>):
  debug_value %0 : $Optional<Double>, let, name "y", argno 1 // id: %1
  %2 = metatype $@thin Double.Type                // user: %26
  switch_enum %0 : $Optional<Double>, case #Optional.some!enumelt: bb2, case #Optional.none!enumelt: bb1 // id: %3
...
// %13                                            // user: %26
bb2(%13 : $Double):                               // Preds: bb0
  switch_enum %0 : $Optional<Double>, case #Optional.some!enumelt: bb4, case #Optional.none!enumelt: bb3 // id: %14
...
// %24                                            // user: %26
bb4(%24 : $Double):                               // Preds: bb2
  // function_ref static Double.* infix(_:_:)
  %25 = function_ref @$sSd1moiyS2d_SdtFZ : $@convention(method) (Double, Double, @thin Double.Type) -> Double // user: %26
  %26 = apply %25(%13, %24, %2) : $@convention(method) (Double, Double, @thin Double.Type) -> Double // user: %27
  return %26 : $Double                            // id: %27
} // end sil function '$s3opt7square21yS2dSg_tF'

However, instead of adjoint accumulation (for %0) we just created new adjoint each time. Funny enough, adjoint buffer version for optionals of non-loadable types was correct – adjoint buffers were accumulated together.

Likely the non-optional switch_enum case is broken as well, I just do not have a testcase that directly exposes the problem.

Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

@asl
Copy link
Contributor Author

asl commented Jul 5, 2024

@swift-ci please test

@asl
Copy link
Contributor Author

asl commented Jul 11, 2024

@swift-ci please test

@asl asl enabled auto-merge (squash) July 11, 2024 22:16
@asl asl merged commit 9d63e0f into main Jul 12, 2024
5 checks passed
@asl asl deleted the 74978-fix branch July 12, 2024 04:48
@asl asl added the AutoDiff label Dec 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Incorrect gradients in some optional-related cases
2 participants