
[SR-12153] Handle differentiation of casting instructions #53489

@dan-zheng

Description

Previous ID SR-12153
Radar None
Original Reporter @dan-zheng
Type Sub-task
Status Closed
Resolution Done
Additional Detail from JIRA
Votes 0
Component/s
Labels Sub-task
Assignee @dan-zheng
Priority Medium

md5: 30381c924b0a7bcac2fa8a47923d6aba

Parent-Task:

  • SR-12148 Make differentiation work with class types

Issue Description:

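Class casting instructions such as `upcast` and `unchecked_ref_cast` are generated in subclass initializers, among other places.
They should be handled during pullback generation. Currently the pullback emitter does not handle `unchecked_ref_cast`, so differentiating through `Sub.init(base:)` fails with "expression is not differentiable".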

class Super : Differentiable {
  var base: Float

  // TODO(TF-645): Remove `vjpInit` when differentiation supports
  // `ref_element_addr`.
  @differentiable(vjp: vjpInit)
  init(base: Float) { self.base = base }

  static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) {
    return (Super(base: base), { x in x.base })
  }
}

// `Sub.init(base:)` is automatically inherited from `Super`.
// In SIL, it uses `upcast` to convert `self` to `Super` and `unchecked_ref_cast` to downcast the result back to `Sub`.
class Sub : Super {}

print(pullback(at: 2) { x in Super(base: x) }(Super.TangentVector(base: 1)))
print(pullback(at: 2) { x in Sub(base: x) }(Super.TangentVector(base: 1)))

[AD] Running PullbackEmitter on
// Sub.init(base:)
sil hidden [differentiable source 0 wrt 0 jvp @AD__$s4main3SubC4baseACSf_tcfc__jvp_src_0_wrt_0 vjp @AD__$s4main3SubC4baseACSf_tcfc__vjp_src_0_wrt_0] @$s4main3SubC4baseACSf_tcfc : $@convention(method) (Float, @owned Sub) -> @owned Sub {
// %0                                             // users: %8, %3
// %1                                             // user: %4
bb0(%0 : $Float, %1 : $Sub):
  %2 = alloc_stack $Sub, let, name "self"         // users: %11, %10, %5, %4, %13, %14
  debug_value %0 : $Float, let, name "base", argno 1 // id: %3
  store %1 to %2 : $*Sub                          // id: %4
  %5 = load %2 : $*Sub                            // user: %6
  %6 = upcast %5 : $Sub to $Super                 // user: %8
  // function_ref Super.init(base:)
  %7 = function_ref @$s4main5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super // user: %8
  %8 = apply %7(%0, %6) : $@convention(method) (Float, @owned Super) -> @owned Super // user: %9
  %9 = unchecked_ref_cast %8 : $Super to $Sub     // user: %10
  store %9 to %2 : $*Sub                          // id: %10
  %11 = load %2 : $*Sub                           // users: %15, %12
  strong_retain %11 : $Sub                        // id: %12
  destroy_addr %2 : $*Sub                         // id: %13
  dealloc_stack %2 : $*Sub                        // id: %14
  return %11 : $Sub                               // id: %15
} // end sil function '$s4main3SubC4baseACSf_tcfc'
...

[AD] Unhandled instruction in adjoint emitter:   %9 = unchecked_ref_cast %8 : $Super to $Sub     // user: %10
[AD] Diagnosing non-differentiability.
[AD] For instruction:
  %9 = unchecked_ref_cast %8 : $Super to $Sub     // user: %10
[AD] With invoker:
(differentiation_invoker autodiff_function_inst=(  %5 = autodiff_function [wrt 0] [order 1] %3 : $@convention(method) (Float, @owned Sub) -> @owned Sub // users: %8, %6
))
<unknown>:0: error: expression is not differentiable
<unknown>:0: note: expression is not differentiable
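The following is a hand-written model in plain Swift of how the casts in `Sub.init(base:)` could be treated during pullback generation; it is not the compiler's implementation. Because `Sub` declares no stored properties, it reuses `Super.TangentVector`, so a class cast changes only the static type and one reasonable treatment is an identity pullback for both `upcast` and `unchecked_ref_cast`. The names `Tangent`, `superInitPullback`, `castPullback`, and `vjpSubInit` are hypothetical, and the sketch deliberately avoids the differentiation module so it runs on any Swift toolchain.

```swift
// Hypothetical stand-in for the synthesized `Super.TangentVector`.
// `Sub` adds no stored properties, so it shares this tangent type.
struct Tangent: Equatable {
  var base: Float
}

class Super {
  var base: Float
  init(base: Float) { self.base = base }
}

// Inherits `init(base:)` from `Super`, as in the issue.
final class Sub: Super {}

// Mirrors the pullback closure returned by `Super.vjpInit` in the issue.
func superInitPullback(_ v: Tangent) -> Float {
  return v.base
}

// Assumed pullback for `upcast` and `unchecked_ref_cast`: the cast does not
// change any stored values, so the cotangent passes through unchanged.
func castPullback(_ v: Tangent) -> Tangent {
  return v
}

// Hypothetical VJP of `Sub.init(base:)`. The pullback visits the active SIL
// instructions in reverse: first the downcast (%9), then the call to
// `Super.init(base:)` (%8). The `upcast` of `self` (%6) does not depend on
// `base`, so it is not on the active path here.
func vjpSubInit(base: Float) -> (Sub, (Tangent) -> Float) {
  let result = Sub(base: base)
  return (result, { v in superInitPullback(castPullback(v)) })
}

let (value, subPullback) = vjpSubInit(base: 2)
print(value.base)                       // prints "2.0"
print(subPullback(Tangent(base: 1)))    // prints "1.0"
```

With a seed of `base: 1`, the modeled pullback for `Sub` returns the same value that `vjpInit`'s pullback returns for `Super`, which is the behavior the second `print` in the reproducer expects once the casts are handled.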
